7 - loan: has personal loan? (categorical: 'no','yes','unknown')
#### related with the last contact of the current campaign:
8 - contact: contact communication type (categorical: 'cellular','telephone')
11 - duration: last contact duration, in seconds (numeric). Important note: this attribute highly affects the output target (e.g., if duration=0 then y='no'). Yet, the duration is not known before a call is performed. Also, after the end of the call y is obviously known. Thus, this input should only be included for benchmark purposes and should be discarded if the intention is to have a realistic predictive model.
#### other attributes:
15 - poutcome: outcome of the previous marketing campaign (categorical: 'failure','nonexistent','success')
#### social and economic context attributes
Output variable (desired target):
# num_instances = 0
# test_ratio = 0.3
# test_size = int(num_instances * 0.3)
# train_size = num_instances - test_size
# #n_iter = 40
We define below all the libraries which will be utilized along the coursework
import plotly.offline as pyo
import plotly.graph_objs as go
# Set notebook mode to work in offline
pyo.init_notebook_mode()
!pip install scikit-learn
#!pip install six
!pip install imbalanced-learn
!pip install category_encoders
!pip install pytorch_lightning
!pip install skorch
!pip install torchviz
#!pip install collections
!pip install -U git+https://github.com/scikit-learn-contrib/imbalanced-learn.git
# Basic libraries for data exploration and preprocessing
import six
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from itertools import combinations
from scipy.stats import chi2_contingency
import category_encoders as ce
from sklearn import preprocessing
from sklearn.preprocessing import LabelEncoder, OrdinalEncoder, OneHotEncoder, StandardScaler, MinMaxScaler, RobustScaler
import matplotlib.ticker as mticker
from matplotlib.colors import Normalize
#from google.colab import files
# Library to measure time between executions
import time
from datetime import datetime
import math
from math import sqrt
from itertools import zip_longest
# Import basic calculations
from numpy import sqrt
from numpy import argmax
# Libraries to apply metric scores and support the model selection
from sklearn.metrics import confusion_matrix, precision_recall_curve,auc, roc_auc_score, roc_curve, recall_score, \
classification_report,accuracy_score, fbeta_score, make_scorer, average_precision_score, precision_score, f1_score
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold, GridSearchCV, RandomizedSearchCV, learning_curve
# Library to apply oversampling algorithm and plot imbalanced dataset using SMOTE
#from collections import Counter
from imblearn.over_sampling import SMOTE, BorderlineSMOTE, ADASYN
from imblearn.combine import SMOTETomek
from imblearn.under_sampling import RandomUnderSampler, OneSidedSelection
from scipy.stats import loguniform, randint, uniform
# Support Vector Machine Imports
from sklearn import datasets
from sklearn import metrics
from sklearn.svm import SVC
# Pytorch related libraries to use with MLP implementation
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn.functional as F
from torch import optim # PyTorch optimizer
#from pytorch_lightning.metrics import Accuracy
from torch.optim import Optimizer
from pytorch_lightning import metrics
# Library to implement Pipeline
from sklearn.pipeline import Pipeline
from imblearn.pipeline import make_pipeline
from imblearn.pipeline import Pipeline as imbPipeline
# Import from skorch
from skorch.callbacks import EpochScoring, BatchScoring, EarlyStopping, Callback, Checkpoint
from skorch import NeuralNetClassifier, NeuralNetBinaryClassifier
#from skorch.history import History
# Support Vector Machine Imports
from sklearn import datasets
from sklearn import metrics
from sklearn.svm import LinearSVC
from sklearn.model_selection import GridSearchCV
# Scikitplot library to plot Confusion Matrix, ROC and Precision x Recall Curves
#import scikitplot as skplt
# Visualize NN architecture
from torchviz import make_dot
# Import pickle Package to save best learning models (MLP and SVM)
import pickle as pkl
from joblib import dump, load
# Select the GPU when one is available, otherwise fall back to CPU for all torch work.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# import warnings filter
from warnings import simplefilter
# ignore all future warnings
simplefilter(action='ignore', category=FutureWarning)
/usr/local/lib/python3.7/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
class MidpointNormalize(Normalize):
    """Colormap normalisation that pins an arbitrary data value to the
    centre of the colormap: vmin maps to 0, midpoint to 0.5, vmax to 1."""

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        # Piecewise-linear interpolation between the three anchor points.
        anchors = [self.vmin, self.midpoint, self.vmax]
        levels = [0, 0.5, 1]
        return np.ma.masked_array(np.interp(value, anchors, levels))
def div0( a, b ):
    """ ignore / 0, div0( [-1, 0, 1], 0 ) -> [0, 0, 0] """
    # Element-wise a / b where divisions by zero (and 0/0) yield 0
    # instead of inf / -inf / NaN.
    with np.errstate(divide='ignore', invalid='ignore'):
        c = np.true_divide( a, b )
        # Fix/generalize: the original `c[~np.isfinite(c)] = 0` raised on
        # scalar (0-d) results; np.where handles scalars and arrays alike.
        c = np.where(np.isfinite(c), c, 0)
    return c
def plot_confusion_matrix2(cm, classes,
                           normalize=False,
                           title='Confusion matrix',
                           cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.

    Parameters
    ----------
    cm : 2-D array of counts (rows = true labels, columns = predictions).
    classes : sequence of tick labels, one per class.
    normalize : if True, convert each row to proportions before annotating.
    title : figure title.
    cmap : matplotlib colormap used for the image.
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=0)
    plt.yticks(tick_marks, classes)
    if normalize:
        # Row-wise normalisation: each row sums to 1.
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    # Cells darker than half the max get white text for contrast.
    thresh = cm.max() / 2.
    # Fix: the original called itertools.product, but the file only imports
    # names *from* itertools (combinations, zip_longest), so the bare
    # `itertools` module was never bound and this raised NameError.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, cm[i, j],
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
def plot_confusion_matrix(cf_matrix, target_names=None, plot_type='', savefig='No'):
    """
    Plot a 2x2 confusion matrix as an annotated seaborn heatmap, with a
    precision/recall/accuracy/F1 summary printed under the x axis.

    Parameters
    ----------
    cf_matrix : 2x2 array of counts, layout [[TN, FP], [FN, TP]].
    target_names : optional pair of class labels for the axis ticks.
    plot_type : free-text tag used in the saved file name.
    savefig : 'Yes' saves the figure to a timestamped PNG.
    """
    group_names = ['TN', 'FP', 'FN', 'TP']
    group_counts = ["{0:.0f}".format(value) for value in
                    cf_matrix.flatten()]
    group_percentages = ["{0:.2%}".format(value) for value in
                         cf_matrix.flatten() / np.sum(cf_matrix)]
    # Each cell shows its name, raw count and share of all samples.
    labels = [f"{v1}\n{v2}\n{v3}" for v1, v2, v3 in
              zip(group_names,
                  group_counts,
                  group_percentages)
              ]
    labels = np.asarray(labels).reshape(2, 2)
    fig = plt.figure(figsize=(14, 8))
    res = sns.heatmap(cf_matrix, annot=labels, fmt='', cmap='Blues', annot_kws={"size": 14})
    if target_names:
        # Fix: the original set the ticks twice (integer positions, then
        # cell-centre positions); only the last call takes effect, so the
        # redundant first pair was removed.
        plt.xticks([0.5, 1.5], target_names, va='center')
        plt.yticks([0.5, 1.5], target_names, va='center')
    # Summary metrics derived from the raw counts (positive class = index 1).
    precision = cf_matrix[1, 1] / sum(cf_matrix[:, 1])
    recall = cf_matrix[1, 1] / sum(cf_matrix[1, :])
    accuracy = np.trace(cf_matrix) / float(np.sum(cf_matrix))
    # Renamed from `f1_score` so the local no longer shadows sklearn's
    # imported f1_score function.
    f1 = 2 * precision * recall / (precision + recall)
    stats_text = "\n\n\nPrecision.={:0.3f}\nRecall={:0.3f}\n\nAccuracy={:0.3f}\nF1 Score={:0.3f}".format(
        precision, recall, accuracy, f1)
    plt.xlabel('Predicted label {}'.format(stats_text), fontsize=12)
    plt.ylabel("True Label", fontsize=12)
    plt.show()
    if savefig == 'Yes':
        fig.savefig('plot_confusion_matrix_' + plot_type + '_' + datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + '.png')
def plot_pr_auc_thresholds(precision, recall, thresh, y_test, y_pred_proba, plot_type='', savefig='No'):
    """
    Interactive (plotly) precision-recall curve with threshold annotations.

    Parameters
    ----------
    precision, recall, thresh : arrays from sklearn's precision_recall_curve
        (thresh is one element shorter; zip_longest pads the gap with None).
    y_test : true binary labels (1 = positive class).
    y_pred_proba : predicted probabilities for the positive class.
    plot_type : free-text tag used in the title and output file name.
    savefig : 'Yes' exports the figure to a timestamped PNG.
    """
    pr_df = pd.DataFrame(zip_longest(precision, recall, thresh), columns=["Precision", "Recall", "Threshold"])
    # convert to f score
    fscore = div0((2 * precision * recall), (precision + recall))
    # locate the index of the largest f score
    ix = np.argmax(fscore)
    fig = px.area(pr_df, x='Precision', y='Recall', hover_data=['Threshold'], title='PR-AUC (' + plot_type + ')')
    # No-skill baseline = prevalence of the positive class.
    no_skill = len(y_test[y_test == 1]) / len(y_test)
    fig.add_shape(type='line',
                  x0=0,
                  y0=no_skill,
                  x1=1,
                  y1=no_skill,
                  line=dict(color='Black', dash='dash'),
                  xref='x',
                  yref='y'
                  )
    fig.add_trace(go.Scatter(
        x=[0.5],
        y=[no_skill + 0.05],
        mode="text",
        text=["Random Classifier:(P:N = 1:9)"],
        textposition="bottom center",
        textfont=dict(
            family="Arial Black",
            size=14,
            color="Black"
        )
    ))
    # Annotate roughly 20 evenly spaced points with their decision threshold.
    for i in range(len(pr_df)):
        if i % (len(pr_df) / 20) == 0:
            fig.add_annotation(x=pr_df.iloc[i][0], y=pr_df.iloc[i][1],
                               text=f'Thresh={pr_df.iloc[i][2]:.2f}',
                               showarrow=True, arrowhead=1)
    # Fix: the original labelled this value 'Optimal Thresh' although the
    # number shown is the maximal F-score, not the threshold itself.
    fig.add_annotation(x=pr_df.iloc[ix][0], y=pr_df.iloc[ix][1],
                       text=f'Optimal F-score={fscore[ix]:.2f}',
                       showarrow=True, arrowcolor='red',
                       bgcolor='white', bordercolor='black', borderwidth=1,
                       font=dict(size=16, color="#FF0000"))
    fig.add_annotation(x=0.5, y=0.5, text=f'AUC={average_precision_score(y_test, y_pred_proba):.2f}',
                       showarrow=False, font=dict(size=20, color="#FF0000", family="Arial Black"))
    fig.layout.update(showlegend=False)
    if savefig == 'Yes':
        # Fix: plotly figures have no savefig(); write_image() is the
        # static-export API (requires the kaleido/orca engine).
        fig.write_image('plot_precision_recall' + plot_type + '_' + datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + '.png')
    fig.show()
def precision_recall_threshold(p, r, thresholds, y_test, y_pred_proba, t=0.5, plot_type='', savefig='No'):
    """
    plots the precision recall curve and shows the current value for each
    by identifying the classifier's threshold (t).

    Parameters
    ----------
    p, r, thresholds : arrays from sklearn's precision_recall_curve
        (thresholds is one element shorter than p/r).
    y_test : true binary labels (1 = positive class).
    y_pred_proba : predicted probabilities for the positive class.
    t : decision threshold to highlight on the curve (default 0.5).
    plot_type : free-text tag appended to the plot title.
    savefig : accepted for interface symmetry with the sibling plot
        helpers; currently unused here.
    """
    font = {'family': 'Arial',
            'color': 'red',
            'weight': 'normal',
            'size': 18,
            }
    # convert to f score
    # Fix: the original read the undefined globals `precision`/`recall`
    # here instead of the `p`/`r` parameters, raising NameError at runtime.
    fscore = div0((2 * p * r), (p + r))
    # locate the index of the largest f score
    ix = np.argmax(fscore)
    # plot the curve
    plt.figure(figsize=(10, 10))
    plt.title('PR-AUC (' + plot_type + ')')
    plt.step(r, p, color='b', alpha=0.2,
             where='post')
    # No-skill baseline = prevalence of the positive class.
    no_skill = len(y_test[y_test == 1]) / len(y_test)
    plt.plot([0, 1], [no_skill, no_skill], linestyle='--', color='black')
    plt.fill_between(r, p, step='post', alpha=0.2,
                     color='b')
    plt.text(0.0, no_skill + 0.03, 'Random Classifier:(P:N = 1:9)', color='black', fontsize=16)
    plt.text(0.45, 0.3, 'AUC = %s' % (round(average_precision_score(y_test, y_pred_proba), 2)), fontdict=font)
    plt.text(r[ix], p[ix] + 0.03, 'Optimized Threshold', color='red', fontsize=14)
    plt.scatter(r[ix], p[ix], marker='o', color='red')
    plt.ylim([-0.05, 1.05])
    plt.xlim([-0.05, 1.05])
    plt.xlabel('Recall', fontsize=16)
    plt.ylabel('Precision', fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    # plot the current threshold on the line
    close_default_clf = np.argmin(np.abs(thresholds - t))
    plt.plot(r[close_default_clf], p[close_default_clf], '^', c='k',
             markersize=15)
class Dataset(torch.utils.data.Dataset):
    #'Characterizes a dataset for PyTorch'
    # NOTE(review): each entry of `features` is passed to torch.load() in
    # __getitem__, so entries appear to be file paths / IDs, and `labels`
    # a mapping keyed by those IDs — confirm against the caller, since a
    # tabular dataset like this one usually holds in-memory arrays.
    def __init__(self, features, labels):
        'Initialization'
        self.labels = labels
        self.features = features
    def __len__(self):
        'Denotes the total number of samples'
        return len(self.features)
    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        ID = self.features[index]
        # Load data and get label
        # NOTE(review): torch.load(ID) deserialises a tensor from disk;
        # this will fail if `features` already holds numeric data — verify.
        X = torch.load(ID)
        y = self.labels[ID]
        return X, y
def plot_roc_auc_thresholds(fpr, tpr, thresh, y_test, y_pred_proba, plot_type='', savefig='No'):
    """
    Interactive (plotly) ROC curve with threshold annotations.

    Parameters
    ----------
    fpr, tpr, thresh : arrays from sklearn's roc_curve.
    y_test : true binary labels (1 = positive class).
    y_pred_proba : predicted probabilities for the positive class.
    plot_type : free-text tag used in the title and output file name.
    savefig : 'Yes' exports the figure to a timestamped PNG.
    """
    roc_df = pd.DataFrame(zip_longest(fpr, tpr, thresh), columns=["FPR", "TPR", "Threshold"])
    # calculate the g-mean for each threshold
    gmeans = np.sqrt(tpr * (1 - fpr))
    # locate the index of the largest g-mean
    ix = np.argmax(gmeans)
    fig = px.area(roc_df, x='FPR', y='TPR', hover_data=['Threshold'], title='ROC-AUC (' + plot_type + ')')
    # Note: the original also computed an unused `no_skill` prevalence
    # here; for ROC the chance baseline is simply the y=x diagonal below.
    fig.add_shape(type='line',
                  x0=0,
                  y0=0,
                  x1=1,
                  y1=1,
                  line=dict(color='Black', dash='dash'),
                  xref='x',
                  yref='y'
                  )
    fig.add_trace(go.Scatter(
        x=[0.25],
        y=[0.1],
        mode="text",
        text=["Random Classifier:(P:N = 1:9)"],
        textposition="bottom center",
        textfont=dict(
            family="Arial Black",
            size=14,
            color="Black"
        )
    ))
    # Annotate roughly 20 evenly spaced points with their decision threshold.
    for i in range(len(roc_df)):
        if i % (len(roc_df) / 20) == 0:
            fig.add_annotation(x=roc_df.iloc[i][0], y=roc_df.iloc[i][1],
                               text=f'Thresh={roc_df.iloc[i][2]:.2f}',
                               showarrow=True, arrowhead=1)
    # Fix: the original labelled this value 'Optimal Thresh' although the
    # number shown is the maximal g-mean, not the threshold itself.
    fig.add_annotation(x=roc_df.iloc[ix][0], y=roc_df.iloc[ix][1],
                       text=f'Optimal G-Mean={gmeans[ix]:.2f}',
                       showarrow=True, arrowcolor='red',
                       bgcolor='white', bordercolor='black', borderwidth=1,
                       font=dict(size=16, color="#FF0000"))
    fig.add_annotation(x=0.5, y=0.5, text=f'AUC={roc_auc_score(y_test, y_pred_proba):.2f}',
                       showarrow=False, font=dict(size=20, color="#FF0000", family="Arial Black"))
    fig.layout.update(showlegend=False)
    if savefig == 'Yes':
        # Fix: plotly figures have no savefig(); write_image() is the
        # static-export API (requires the kaleido/orca engine).
        fig.write_image('plot_roc_auc_' + plot_type + '_' + datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + '.png')
    fig.show()
def plot_roc_curve(fpr, tpr, label=None):
    """
    Draw a ROC curve together with the chance diagonal.

    Modified from
    Hands-On Machine learning with Scikit-Learn and TensorFlow; p.91
    """
    plt.figure(figsize=(8, 8))
    plt.title('ROC Curve')
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate (Recall)")
    # Model curve first, then the dashed random-guess baseline.
    plt.plot(fpr, tpr, linewidth=2, label=label)
    plt.plot([0, 1], [0, 1], 'k--')
    plt.axis([-0.005, 1, 0, 1.005])
    tick_positions = np.arange(0, 1, 0.05)
    plt.xticks(tick_positions, rotation=90)
    plt.legend(loc='best')
Below we define the AdaBound algorithm class, which will be used as one of the neural network optimizers.
class AdaBound(Optimizer):
    """ AdaBound code from https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py
    Implements AdaBound algorithm.
    It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): Adam learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        final_lr (float, optional): final (SGD) learning rate (default: 0.1)
        gamma (float, optional): convergence speed of the bound functions (default: 1e-3)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm
    .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate:
        https://openreview.net/forum?id=Bkg3g2R9FX
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3,
                 eps=1e-8, weight_decay=0, amsbound=False):
        # Validate hyper-parameters up front so a bad configuration fails fast.
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= final_lr:
            raise ValueError("Invalid final learning rate: {}".format(final_lr))
        if not 0.0 <= gamma < 1.0:
            raise ValueError("Invalid gamma parameter: {}".format(gamma))
        defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps,
                        weight_decay=weight_decay, amsbound=amsbound)
        super(AdaBound, self).__init__(params, defaults)
        # Remember each group's initial lr so step() can rescale final_lr
        # proportionally when an lr scheduler later changes group['lr'].
        self.base_lrs = list(map(lambda group: group['lr'], self.param_groups))
    def __setstate__(self, state):
        super(AdaBound, self).__setstate__(state)
        # Checkpoints saved before the amsbound option existed default it off.
        for group in self.param_groups:
            group.setdefault('amsbound', False)
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, please consider SparseAdam instead')
                amsbound = group['amsbound']
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsbound:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsbound:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1
                if group['weight_decay'] != 0:
                    grad = grad.add(group['weight_decay'], p.data)
                # Decay the first and second moment running average coefficient
                # NOTE(review): add_/addcmul_ here use the legacy positional
                # (alpha, tensor) signature; recent PyTorch releases require
                # keyword alpha=/value= — confirm against the installed version.
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                if amsbound:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
                # Applies bounds on actual learning rate
                # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay
                final_lr = group['final_lr'] * group['lr'] / base_lr
                # Bounds converge toward final_lr as step count grows.
                lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1))
                upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step']))
                step_size = torch.full_like(denom, step_size)
                # Clamp the per-element step size into [lower_bound, upper_bound],
                # scale by the first-moment estimate, then apply the update.
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
                p.data.add_(-step_size)
        return loss
#pd.options.display.max_rows
#pd.get_option("display.max_columns") #Default (20)
# Widen pandas display so all 21 dataset columns are visible in one view.
pd.set_option('display.max_columns',70)
Now we will load the Bank Marketing dataset that will be used for the analysis and critical evaluation between SVM and MLP algorithms.
# Load bank-additional-full.csv file
# The Bank Marketing CSV uses ';' as its field separator.
df = pd.read_csv('bank-additional-full.csv', delimiter=';')
df.head()
| age | job | marital | education | default | housing | loan | contact | month | day_of_week | duration | campaign | pdays | previous | poutcome | emp.var.rate | cons.price.idx | cons.conf.idx | euribor3m | nr.employed | y | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 56 | housemaid | married | basic.4y | no | no | no | telephone | may | mon | 261 | 1 | 999 | 0 | nonexistent | 1.1 | 93.994 | -36.4 | 4.857 | 5191.0 | no |
| 1 | 57 | services | married | high.school | unknown | no | no | telephone | may | mon | 149 | 1 | 999 | 0 | nonexistent | 1.1 | 93.994 | -36.4 | 4.857 | 5191.0 | no |
| 2 | 37 | services | married | high.school | no | yes | no | telephone | may | mon | 226 | 1 | 999 | 0 | nonexistent | 1.1 | 93.994 | -36.4 | 4.857 | 5191.0 | no |
| 3 | 40 | admin. | married | basic.6y | no | no | no | telephone | may | mon | 151 | 1 | 999 | 0 | nonexistent | 1.1 | 93.994 | -36.4 | 4.857 | 5191.0 | no |
| 4 | 56 | services | married | high.school | no | no | yes | telephone | may | mon | 307 | 1 | 999 | 0 | nonexistent | 1.1 | 93.994 | -36.4 | 4.857 | 5191.0 | no |
df.describe()
| age | duration | campaign | pdays | previous | emp.var.rate | cons.price.idx | cons.conf.idx | euribor3m | nr.employed | |
|---|---|---|---|---|---|---|---|---|---|---|
| count | 41188.00000 | 41188.000000 | 41188.000000 | 41188.000000 | 41188.000000 | 41188.000000 | 41188.000000 | 41188.000000 | 41188.000000 | 41188.000000 |
| mean | 40.02406 | 258.285010 | 2.567593 | 962.475454 | 0.172963 | 0.081886 | 93.575664 | -40.502600 | 3.621291 | 5167.035911 |
| std | 10.42125 | 259.279249 | 2.770014 | 186.910907 | 0.494901 | 1.570960 | 0.578840 | 4.628198 | 1.734447 | 72.251528 |
| min | 17.00000 | 0.000000 | 1.000000 | 0.000000 | 0.000000 | -3.400000 | 92.201000 | -50.800000 | 0.634000 | 4963.600000 |
| 25% | 32.00000 | 102.000000 | 1.000000 | 999.000000 | 0.000000 | -1.800000 | 93.075000 | -42.700000 | 1.344000 | 5099.100000 |
| 50% | 38.00000 | 180.000000 | 2.000000 | 999.000000 | 0.000000 | 1.100000 | 93.749000 | -41.800000 | 4.857000 | 5191.000000 |
| 75% | 47.00000 | 319.000000 | 3.000000 | 999.000000 | 0.000000 | 1.400000 | 93.994000 | -36.400000 | 4.961000 | 5228.100000 |
| max | 98.00000 | 4918.000000 | 56.000000 | 999.000000 | 7.000000 | 1.400000 | 94.767000 | -26.900000 | 5.045000 | 5228.100000 |
df.nunique()
age 78 job 12 marital 4 education 8 default 3 housing 3 loan 3 contact 2 month 10 day_of_week 5 duration 1544 campaign 42 pdays 27 previous 8 poutcome 3 emp.var.rate 10 cons.price.idx 26 cons.conf.idx 26 euribor3m 316 nr.employed 11 y 2 dtype: int64
As stated in the dataset description the Duration column is not relevant for prediction purposes, since it is highly correlated to the target variable. So we will drop this column beforehand.
# Drop the leaky 'duration' feature (see dataset notes above);
# errors='ignore' keeps this cell idempotent when re-run.
df.drop(columns=['duration'],inplace=True,errors='ignore')
And we count again the unique values after dropping the variable.
df.nunique()
age 78 job 12 marital 4 education 8 default 3 housing 3 loan 3 contact 2 month 10 day_of_week 5 campaign 42 pdays 27 previous 8 poutcome 3 emp.var.rate 10 cons.price.idx 26 cons.conf.idx 26 euribor3m 316 nr.employed 11 y 2 dtype: int64
df.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 41188 entries, 0 to 41187 Data columns (total 20 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 age 41188 non-null int64 1 job 41188 non-null object 2 marital 41188 non-null object 3 education 41188 non-null object 4 default 41188 non-null object 5 housing 41188 non-null object 6 loan 41188 non-null object 7 contact 41188 non-null object 8 month 41188 non-null object 9 day_of_week 41188 non-null object 10 campaign 41188 non-null int64 11 pdays 41188 non-null int64 12 previous 41188 non-null int64 13 poutcome 41188 non-null object 14 emp.var.rate 41188 non-null float64 15 cons.price.idx 41188 non-null float64 16 cons.conf.idx 41188 non-null float64 17 euribor3m 41188 non-null float64 18 nr.employed 41188 non-null float64 19 y 41188 non-null object dtypes: float64(5), int64(4), object(11) memory usage: 6.3+ MB
# Class balance of the target: bar chart plus exact percentages.
pd.value_counts(df['y']).plot.bar()
plt.title('Subscription of term deposit (Classes) - Histogram')
plt.xlabel('Class')
plt.ylabel('Frequency')
df['y'].value_counts()
print('% of Subscription of term deposit (No)', round(len(df[df.y=='no']) / len(df.index) * 100,2) )
print('% of Subscription of term deposit (Yes)', round(len(df[df.y=='yes']) / len(df.index) * 100,2) )
% of Subscription of term deposit (No) 88.73 % of Subscription of term deposit (Yes) 11.27
As seen above we have a binary classification problem with highly class-imbalanced data, characterized by a ratio of approximately 1:9 between the minority ('yes') and majority ('no') classes respectively. So we should evaluate the application of a method to produce a better-balanced dataset.
We will investigate the dataset for missing values in the columns
df.isnull().sum()
age 0 job 0 marital 0 education 0 default 0 housing 0 loan 0 contact 0 month 0 day_of_week 0 campaign 0 pdays 0 previous 0 poutcome 0 emp.var.rate 0 cons.price.idx 0 cons.conf.idx 0 euribor3m 0 nr.employed 0 y 0 dtype: int64
df.head()
| age | job | marital | education | default | housing | loan | contact | month | day_of_week | campaign | pdays | previous | poutcome | emp.var.rate | cons.price.idx | cons.conf.idx | euribor3m | nr.employed | y | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 56 | housemaid | married | basic.4y | no | no | no | telephone | may | mon | 1 | 999 | 0 | nonexistent | 1.1 | 93.994 | -36.4 | 4.857 | 5191.0 | no |
| 1 | 57 | services | married | high.school | unknown | no | no | telephone | may | mon | 1 | 999 | 0 | nonexistent | 1.1 | 93.994 | -36.4 | 4.857 | 5191.0 | no |
| 2 | 37 | services | married | high.school | no | yes | no | telephone | may | mon | 1 | 999 | 0 | nonexistent | 1.1 | 93.994 | -36.4 | 4.857 | 5191.0 | no |
| 3 | 40 | admin. | married | basic.6y | no | no | no | telephone | may | mon | 1 | 999 | 0 | nonexistent | 1.1 | 93.994 | -36.4 | 4.857 | 5191.0 | no |
| 4 | 56 | services | married | high.school | no | no | yes | telephone | may | mon | 1 | 999 | 0 | nonexistent | 1.1 | 93.994 | -36.4 | 4.857 | 5191.0 | no |
We will split the data processing between quantitative and categorical (ordinal/nominal) variables
First let's see what is the distribution of quantitative variables
# Quantitative (numeric) feature names, with and without the target column.
quant_var = ['age','campaign','pdays','previous','emp.var.rate',
             'cons.price.idx','cons.conf.idx','euribor3m','nr.employed']
quant_var_tgt = ['age','campaign','pdays','previous','emp.var.rate',
                 'cons.price.idx','cons.conf.idx','euribor3m','nr.employed','y']
#df.hist(column=quant_var,figsize=(16,10))
# One KDE plot per numeric feature, coloured by target class.
f, axes = plt.subplots(3, 3, figsize=(16, 12))
for ax, feature in zip(axes.flat, quant_var):
    graph1 = sns.kdeplot(data=df, x=feature, hue='y', alpha=0.4, ax=ax)
    ax.xaxis.get_label().set_fontsize(16)
    ax.yaxis.get_label().set_fontsize(16)
    xticks_loc = ax.get_xticks().tolist()
    yticks_loc = ax.get_yticks().tolist()
    ax.tick_params(axis = 'both', which = 'major', labelsize = 14)
    plt.setp(ax.get_legend().get_texts(), fontsize='14') # for legend text
    plt.setp(ax.get_legend().get_title(), fontsize='16') # for legend title
plt.tight_layout()
plt.show()
#f.savefig('continuous_distribution_plot.png')
# Box plots of the same features, to surface outliers.
fig, axes = plt.subplots(3,3,figsize=(15,9))
for i,el in enumerate(quant_var):
    a = df.boxplot(el, ax=axes.flatten()[i])
    #a = df.boxplot(el, by='y', ax=axes.flatten()[i])
#fig.delaxes(axes[1,1]) # remove empty subplot
plt.tight_layout()
Ideally we would like to verify if there is any high correlation between the independent quantitative variables.
# Pearson correlation heatmap of the numeric features.
sns.set_theme(style="white")
# NOTE(review): 'y' is an object column, so .corr() silently drops it
# here — confirm whether a target correlation was actually intended.
corr_plot_var = quant_var + ['y']
# Compute the correlation matrix
corr = df[corr_plot_var].corr()
# Generate a mask for the upper triangle
mask = np.triu(np.ones_like(corr, dtype=bool))
# Set up the matplotlib figure
f, ax = plt.subplots(figsize=(16, 10))
# Generate a custom diverging colormap
cmap = sns.diverging_palette(230, 20, as_cmap=True)
# Draw the heatmap with the mask and correct aspect ratio
ax = sns.heatmap(corr, mask=mask, cmap=cmap, vmin=-1.0, vmax=1.0, center=0,
                 square=True, linewidths=.5, annot=True, annot_kws={"size": 12})
ax.xaxis.get_label().set_fontsize(16)
ax.yaxis.get_label().set_fontsize(16)
ax.tick_params(axis = 'both', which = 'major', labelsize = 14)
# use matplotlib.colorbar.Colorbar object
cbar = ax.collections[0].colorbar
# here set the labelsize by 14
cbar.ax.tick_params(labelsize=14)
plt.title('Correlation Plot of Target and Features (Continuous)',fontsize=18)
#plt.savefig('continuous_correlation_plot.png')
plt.show()
No high correlation (assuming an absolute value of more than 0.5) was observed between the quantitative variables so we will proceed the analysis considering all the quantitative variables initially selected.
Ordinal
Nominal
We shall divide the analysis between Ordinal and Nominal variables and apply the necessary transformations for the training and test set independently. Ordinal variables should be handled taking into consideration the scale between the values from a discrete space. As for the Nominal variables we could create one column for each value they assume.
Specifically, we will treat the contact variable as an Ordinal variable because it assumes only 2 values, and we expect to eliminate the redundancy that encoding it into two one-hot columns would introduce.
# Categorical (ordinal/nominal) feature names used in the association analysis below.
categ_var = ['default','housing','loan','poutcome','contact','job',
             'marital','education','month','day_of_week']
def cramers_corrected_stat(confusion_matrix):
    """
    Bias-corrected Cramér's V association measure for a contingency table.

    Implements the Bergsma–Wicher correction: the raw phi^2 statistic and
    the table dimensions are both shrunk before the final square root,
    which reduces the upward bias of plain Cramér's V on finite samples.
    """
    chi_sq = chi2_contingency(confusion_matrix)[0]
    total = confusion_matrix.sum().sum()
    phi_sq = chi_sq / total
    n_rows, n_cols = confusion_matrix.shape
    # Bias-corrected phi^2, floored at zero.
    phi_sq_corr = max(0, phi_sq - ((n_cols - 1) * (n_rows - 1)) / (total - 1))
    # Bias-corrected effective row and column counts.
    rows_corr = n_rows - ((n_rows - 1) ** 2) / (total - 1)
    cols_corr = n_cols - ((n_cols - 1) ** 2) / (total - 1)
    denom = min(cols_corr - 1, rows_corr - 1)
    return np.sqrt(phi_sq_corr / denom)
# Pairwise bias-corrected Cramér's V between every categorical variable and the target.
cols = categ_var + ['y']
corrM = np.zeros((len(cols),len(cols)))
# there's probably a nice pandas way to do this
for col1, col2 in combinations(cols, 2):
    idx1, idx2 = cols.index(col1), cols.index(col2)
    corrM[idx1, idx2] = cramers_corrected_stat(pd.crosstab(df[col1], df[col2]))
    # The association matrix is symmetric; mirror the value.
    corrM[idx2, idx1] = corrM[idx1, idx2]
corr = pd.DataFrame(corrM, index=cols, columns=cols)
fig, ax = plt.subplots(figsize=(16, 9))
# NOTE(review): `cmap` is reused from the earlier continuous-correlation cell.
ax = sns.heatmap(corr, annot=True, annot_kws={"size": 10}, cmap=cmap,ax=ax)
ax.xaxis.get_label().set_fontsize(16)
ax.yaxis.get_label().set_fontsize(16)
ax.tick_params(axis = 'both', which = 'major', labelsize = 14)
cbar = ax.collections[0].colorbar
# here set the labelsize by 14
cbar.ax.tick_params(labelsize=14)
ax.set_title("Correlation Plot of Target and Features (Categorical)", fontsize=18)
plt.tight_layout()
#fig.savefig('cramer_categorical_corr_plot.png')
# Count plots of the low-cardinality categorical features, split by target.
fig, ax = plt.subplots(2,3,figsize=(12,6))
sns.countplot(x='default',data=df,hue='y',ax=ax[0][0])
sns.countplot(x='housing',data=df,hue='y',ax=ax[0][1])
sns.countplot(x='loan',data=df,hue='y',ax=ax[0][2])
sns.countplot(x='poutcome',data=df,hue='y',ax=ax[1][0])
sns.countplot(x='contact',data=df,hue='y',ax=ax[1][1])
fig.tight_layout()
# Exact value counts per category, to spot rare levels.
# Default
display(df['default'].value_counts())
# Housing
display(df['housing'].value_counts())
# Loan
display(df['loan'].value_counts())
# Poutcome
display(df['poutcome'].value_counts())
# Contact
display(df['contact'].value_counts())
no 32588 unknown 8597 yes 3 Name: default, dtype: int64
yes 21576 no 18622 unknown 990 Name: housing, dtype: int64
no 33950 yes 6248 unknown 990 Name: loan, dtype: int64
nonexistent 35563 failure 4252 success 1373 Name: poutcome, dtype: int64
cellular 26144 telephone 15044 Name: contact, dtype: int64
Next we can encode the 'Yes','No','Unknown' values and 'Failure','Nonexistent','Success' from (Default, Housing, Loan, and Poutcome respectively) to ordinal numbers, because they represent scales of values in a discrete space, differently from other categorical variables which can experience a certain increase in their category size. Moreover, we know the relationship between the categories.
Therefore, each value would have a different contribution to the target variable. We will consider the OrdinalEncoder method from sklearn and define the order of categories as the following: 'Yes'/'Failure' then 'Unknown'/'Nonexistent' and lastly 'No'/'Success' and will be mapped to integer numerical values to be used as input values of the neural network subsequently.
To accomplish this we will implement the transformation inside the transform_input function (below) and create a new dataframe containing only the ordinal variables to be added later to the rest of the variables.
# enc = OrdinalEncoder(categories=[['yes','unknown','no']]*3)
# Encode Default, Housing, Loan variables
# df_ordinal = df[['default','housing', 'loan','poutcome','contact']].copy()
# df_ordinal[['default','housing', 'loan']] = enc.fit_transform(df_ordinal[['default','housing', 'loan']])
# Encode Poutcome variable
# enc = OrdinalEncoder(categories=[['failure','nonexistant','success']])
# df_ordinal[['poutcome']] = enc.fit_transform(df_ordinal[['poutcome']])
# Encode Contact variable
# enc = OrdinalEncoder(categories=[['cellular','telephone']])
# df_ordinal[['contact']] = enc.fit_transform(df_ordinal[['contact']])
Let's investigate the values assumed by the 3 remaining variables (Job/Marital/Education) to verify if any of the variables has a distance relation between values and can be transformed in a ordinal column.
# Inspect the category levels of the remaining nominal variables.
# Job
print('Job')
display(df['job'].value_counts())
# Marital
print('Marital')
display(df['marital'].value_counts())
# Education
print('Education')
display(df['education'].value_counts())
Job
admin. 10422 blue-collar 9254 technician 6743 services 3969 management 2924 retired 1720 entrepreneur 1456 self-employed 1421 housemaid 1060 unemployed 1014 student 875 unknown 330 Name: job, dtype: int64
Marital
married 24928 single 11568 divorced 4612 unknown 80 Name: marital, dtype: int64
Education
university.degree 12168 high.school 9515 basic.9y 6045 professional.course 5243 basic.4y 4176 basic.6y 2292 unknown 1731 illiterate 18 Name: education, dtype: int64
# Create dummy variables
nom_var = ['job','marital','education']
# One-hot encode the nominal variables; drop_first avoids the dummy-variable trap.
df_nom_var = pd.get_dummies(df[nom_var], columns=nom_var, drop_first=True)
# Save name of the onehot columns
onehot_var = df_nom_var.columns
onehot_var
Index(['job_blue-collar', 'job_entrepreneur', 'job_housemaid',
'job_management', 'job_retired', 'job_self-employed', 'job_services',
'job_student', 'job_technician', 'job_unemployed', 'job_unknown',
'marital_married', 'marital_single', 'marital_unknown',
'education_basic.6y', 'education_basic.9y', 'education_high.school',
'education_illiterate', 'education_professional.course',
'education_university.degree', 'education_unknown', 'day_of_week_mon',
'day_of_week_thu', 'day_of_week_tue', 'day_of_week_wed'],
dtype='object')
df.month.unique()
array(['may', 'jun', 'jul', 'aug', 'oct', 'nov', 'dec', 'mar', 'apr',
'sep'], dtype=object)
sns.countplot(x='month',data=df,hue='y')
<AxesSubplot:xlabel='month', ylabel='count'>
Before applying transformation of categorical variables we should first split the data between training and test sets. One of the reasons we consider this approach is if for instance there is a rare category from one categorical variable in one (or more) of the dataset variables and this category could be in the unseen data (test set). Therefore, we cannot include this category in the training set to avoid data leakage and exposing the test data to the full distribution.
As a best practice we should encode variables by applying the fit on the encoded training dataset, then apply it on both training and test datasets.
# Define features and target variables accordingly
X = df.drop('y', axis=1).values
y = df['y'].values
# Reshape the target into a column vector (n, 1) as expected downstream.
y = y.reshape(y.shape[0], 1)
# Stratified 80/20 split preserves the ~1:9 class ratio in both partitions.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)
print(len(X_train_raw), len(X_test_raw))
print(X_train_raw.shape, y_train.shape)
print(X_test_raw.shape, y_test.shape)
32950 8238 (32950, 19) (32950, 1) (8238, 19) (8238, 1)
# Transform input variables
def transform_input(X_train, X_test, quant_var, nom_var, onehot_var, df):
    """Encode the raw train/test feature matrices into model-ready DataFrames.

    Quantitative columns are passed through unchanged.  Binary/ordered
    categoricals ('default', 'housing', 'loan', 'poutcome', 'contact',
    'day_of_week', 'month') are ordinal-encoded, and the nominal columns in
    `nom_var` are one-hot encoded.  Every encoder is fit on the TRAINING
    split only and then applied to both splits, so no information leaks
    from the test set.

    Parameters
    ----------
    X_train, X_test : raw feature arrays produced by train_test_split.
    quant_var : names of the quantitative columns (copied through as-is).
    nom_var : names of the nominal columns to one-hot encode.
    onehot_var : unused here (kept for call-site compatibility) — the
        one-hot column names are re-derived from the fitted encoder instead.
    df : the full DataFrame, used only to recover the feature column names.

    Returns
    -------
    (train_df_trans, test_df_trans) : DataFrames with columns ordered as
        [ordinal..., one-hot..., quantitative...].
    """
    # Rebuild DataFrames so columns can be addressed by name again
    X_train_df = pd.DataFrame(X_train, columns = df.drop('y', axis=1).columns)
    X_test_df = pd.DataFrame(X_test, columns = df.drop('y', axis=1).columns)
    # Quantitative variables are carried through untouched
    train_df_quant = X_train_df[quant_var]
    test_df_quant = X_test_df[quant_var]
    # Ordinal Encoder for the three yes/unknown/no variables
    # (the same category order is applied to each of the three columns)
    ordn = OrdinalEncoder(categories=[['yes','unknown','no']]*3)
    # Fit on training data only
    ordn.fit(X_train_df[['default','housing', 'loan']])
    # Transform both splits
    X_train_ord1 = ordn.transform(X_train_df[['default','housing', 'loan']])
    X_test_ord1 = ordn.transform(X_test_df[['default','housing', 'loan']])
    # Ordinal Encoder for the poutcome variable
    ordn = OrdinalEncoder(categories=[['failure','nonexistent','success']])
    ordn.fit(np.array(X_train_df['poutcome']).reshape(-1,1))
    X_train_ord2 = ordn.transform(X_train_df[['poutcome']])
    X_test_ord2 = ordn.transform(X_test_df[['poutcome']])
    # Ordinal Encoder for the contact variable
    ordn = OrdinalEncoder(categories=[['cellular','telephone']])
    ordn.fit(np.array(X_train_df['contact']).reshape(-1,1))
    X_train_ord3 = ordn.transform(X_train_df[['contact']])
    X_test_ord3 = ordn.transform(X_test_df[['contact']])
    # Ordinal Encoder for the day_of_week variable (weekday order, not alphabetical)
    ordn = OrdinalEncoder(categories=[['mon','tue','wed','thu','fri']])
    ordn.fit(np.array(X_train_df['day_of_week']).reshape(-1,1))
    X_train_ord4 = ordn.transform(X_train_df[['day_of_week']])
    X_test_ord4 = ordn.transform(X_test_df[['day_of_week']])
    # Ordinal Encoder for the month variable with an explicit calendar mapping
    # ('jan'/'feb' never occur in the data but are mapped for completeness)
    ordn = ce.OrdinalEncoder(cols=['month'], return_df=True, mapping = [{
        'col': 'month', 'mapping': {
        'jan': 1, 'feb': 2, 'mar': 3, \
        'apr': 4, 'may': 5, 'jun': 6, 'jul': 7, \
        'aug': 8, 'sep': 9, 'oct': 10, 'nov': 11, 'dec': 12}}])
    ordn.fit(X_train_df['month'])
    X_train_ord5 = ordn.transform(X_train_df[['month']])
    X_test_ord5 = ordn.transform(X_test_df[['month']])
    # OneHot Encoder for the nominal variables; handle_unknown='error' will
    # raise if the test split contains a category unseen during fit
    ohe = OneHotEncoder(handle_unknown='error') #,drop='first',
    ohe.fit(X_train_df[nom_var])
    # NOTE(review): get_feature_names was renamed get_feature_names_out in
    # scikit-learn >= 1.0 — this call pins an older sklearn version.
    onehot_list = ohe.get_feature_names(nom_var)
    X_train_nom = ohe.transform(X_train_df[nom_var])
    X_test_nom = ohe.transform(X_test_df[nom_var])
    # Reassemble named DataFrames (training); the one-hot result is returned
    # as a sparse matrix, hence the .toarray() below
    train_df_ord1 = pd.DataFrame(X_train_ord1, columns = ['default','housing', 'loan'])
    train_df_ord2 = pd.DataFrame(X_train_ord2, columns = ['poutcome'])
    train_df_ord3 = pd.DataFrame(X_train_ord3, columns = ['contact'])
    train_df_ord4 = pd.DataFrame(X_train_ord4, columns = ['day_of_week'])
    train_df_ord5 = pd.DataFrame(X_train_ord5, columns = ['month'])
    train_df_nom = pd.DataFrame(X_train_nom.toarray(), columns = list(onehot_list))
    # Reassemble named DataFrames (test)
    test_df_ord1 = pd.DataFrame(X_test_ord1, columns = ['default','housing', 'loan'])
    test_df_ord2 = pd.DataFrame(X_test_ord2, columns = ['poutcome'])
    test_df_ord3 = pd.DataFrame(X_test_ord3, columns = ['contact'])
    test_df_ord4 = pd.DataFrame(X_test_ord4, columns = ['day_of_week'])
    test_df_ord5 = pd.DataFrame(X_test_ord5, columns = ['month'])
    test_df_nom = pd.DataFrame(X_test_nom.toarray(), columns = list(onehot_list))
    # Concatenate ordinal + one-hot + quantitative frames column-wise,
    # keeping the same column order for train and test
    train_frames = [train_df_ord1, train_df_ord2, train_df_ord3, train_df_ord4,
                    train_df_ord5, train_df_nom, train_df_quant]
    test_frames = [test_df_ord1, test_df_ord2, test_df_ord3, test_df_ord4,
                   test_df_ord5, test_df_nom, test_df_quant]
    train_df_trans = pd.concat(train_frames, axis=1)
    test_df_trans = pd.concat(test_frames, axis=1)
    return train_df_trans, test_df_trans
# Transform target variable
def transform_target(y_train, y_test):
    """Label-encode the target arrays into integers.

    The encoder is fit on the training labels only and then applied to
    both splits; inputs are flattened to 1-D first, as LabelEncoder
    expects.  Classes are assigned codes in sorted order (so for this
    dataset's 'no'/'yes' labels: 'no' -> 0, 'yes' -> 1).

    Returns (y_train_trans, y_test_trans) as 1-D integer arrays.
    """
    encoder = LabelEncoder()
    train_labels = np.ravel(y_train)
    # fit_transform == fit followed by transform on the same data
    y_train_trans = encoder.fit_transform(train_labels)
    y_test_trans = encoder.transform(np.ravel(y_test))
    return y_train_trans, y_test_trans
# Standardise the quantitative columns (zero mean, unit variance) purely for
# the descriptive-statistics table below — model-pipeline scaling happens
# separately inside the training pipeline.
x_quant = df[quant_var].values #returns a numpy array
scaler_quant = preprocessing.StandardScaler()
x_quant_scaled = scaler_quant.fit_transform(x_quant)
df_quant_var = pd.DataFrame(x_quant_scaled, columns = list(quant_var))
# Round the scaled values to 2 decimals (format to string, then back to float)
df_quant_var = df_quant_var.apply(lambda s: s.apply('{0:.2f}'.format)).astype(float)
# Concatenate scaled dataframe with categorical variables dataframe to print and save a table with descriptive statistics
desc_frames = [df_quant_var,df[categ_var],df['y']]
df_desc_stats = pd.concat(desc_frames, axis=1 )
# Per-class describe(), then reshape (unstack / pivot_table) into a single
# table whose columns are a (statistic, class) MultiIndex for rendering
temp_desc = df_desc_stats.groupby('y').describe(include='all').round(decimals=2).reset_index()
temp_desc2 = temp_desc.unstack().reset_index().loc[2:,:].reset_index()
temp_desc2.rename(columns = {0:'values'}, inplace = True)
df_desc_stats_final = temp_desc2.pivot_table(index='level_0',columns=['level_2','level_1'],values='values',aggfunc='first')
# Column colours for the rendered table: 11 green then 11 blue columns —
# presumably one group per target class; verify against the rendered output
colclass0 = ['palegreen'] * 11
colclass1 = ['royalblue'] * 11
colClass = colclass0 + colclass1
# Statistical Summary
# Render df_desc_stats_final as a matplotlib table (and optionally save it).
from pandas.plotting import table
# NOTE(review): the `table` name imported above is immediately shadowed by
# the ax.table(...) result below, and df2/nrows/ncols/hcell/wcell/hpad/wpad
# feed only the commented-out figsize — effectively dead code.
df2 = df.describe()
nrows, ncols = len(df2)+1, len(df2.columns)
hcell, wcell = 0.3, 0.1 # tweak as per your requirements
hpad, wpad = 0.5, 0.5
fig, ax = plt.subplots(figsize=(8, 14), frameon=False, dpi=200)#figsize=(ncols*wcell+wpad, nrows*hcell+hpad))
table = ax.table(cellText=df_desc_stats_final.values, colWidths = [0.25]*len(df_desc_stats_final.columns),
    rowLabels=df_desc_stats_final.index,
    colLabels=df_desc_stats_final.columns,
    colColours=colClass,
    loc='center')
# Fixed font size plus manual scaling keeps the cells readable at dpi=200
table.auto_set_font_size(False)
table.set_fontsize(18)
table.scale(10,6)
table.auto_set_column_width(col=list(range(len(df_desc_stats_final.columns))))
# Hide the axes so only the table itself is visible
ax.axis('off')
ax.axis('tight')
plt.show()
# Workaround: force a draw so the figure is fully laid out before saving
fig.canvas.draw()
#save the plot as a png file
#fig.savefig('descriptive_statistics_plot.png', bbox_inches='tight')
Now we can call the functions and have the resultant transformed sets
# Transform input data (X): fit encoders on the training split, apply to both
train_df_trans, test_df_trans = transform_input(X_train_raw, X_test_raw, quant_var, nom_var, onehot_var, df)
# Transform output data (y): label-encode the 'no'/'yes' targets
y_train_trans, y_test_trans = transform_target(y_train, y_test)
# Prepare inputs: back to plain numpy arrays for the model
X_train = np.array(train_df_trans)
X_test = np.array(test_df_trans)
# Reshape target variable to a column vector (n_samples, 1)
y_train = y_train_trans.reshape(-1,1)
y_test = y_test_trans.reshape(-1,1)
Now we verify the dimensions of encoded arrays
print(X_train.shape, X_test.shape)
print(y_train.shape, y_test.shape)
(32950, 40) (8238, 40) (32950, 1) (8238, 1)
# max_epochs = 100
# learning_rate = 0.003
# input_dim = 50
# hidden_dim = 50
# output_dim = 2
# # Dataloader Parameters
# params_train = {'batch_size': 64,
# 'shuffle': True,
# 'num_workers': 6}
# params_test = {'batch_size': 1024,
# 'shuffle': False,
# 'num_workers': 6}
# def smote_train_val_loader(x_train, y_train, batch_size_train=64, batch_size_val=64):
# # Shuffle the indices
# indices = np.arange(0,len(x_train)) # build an array
# np.random.shuffle(indices) # shuffle the indicies
# X_train, X_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.1, random_state=42)
# # Define SMOTE instance
# smt = SMOTE(sampling_strategy='minority', random_state=2)
# # Apply SMOTE to training only
# X_train_smote, y_train_smote = smt.fit_resample(X_train,#[indices[:round(len(X_train)*0.8)]],
# y_train) #[indices[:round(len(y_train)*0.8)]])
# # Create X_val and y_val according to the shuffled indexes (Validate on 10% of the imbalanced data)
# #X_val = X_train[indices[-(len(X_train)-round(len(X_train)*0.8)):]]
# #y_val = y_train[indices[-(len(y_train)-round(len(y_train)*0.8)):]]
# # Train set
# # Create feature and targets tensor for training set
# featuresTrain = torch.from_numpy(X_train_smote).type(torch.FloatTensor)
# targetsTrain = torch.from_numpy(y_train_smote).type(torch.FloatTensor)
# # Val set
# # Create feature and targets tensor for validate set
# featuresVal = torch.from_numpy(X_val).type(torch.FloatTensor)
# targetsVal = torch.from_numpy(y_val).type(torch.FloatTensor)
# # Tensor Datasets
# train_data = torch.utils.data.TensorDataset(featuresTrain,targetsTrain)
# val_data = torch.utils.data.TensorDataset(featuresVal,targetsVal)
# # Dataloader Generators
# train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size_train,shuffle = True,
# num_workers= 0)
# val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size_val,shuffle = True,
# num_workers= 0)
# return train_loader, val_loader
# def test_data_loader(x_test, y_test):
# # Test set
# # create feature and targets tensor for validate sets. We need variable to accumulate gradients.
# featuresTest = torch.from_numpy(x_test).type(torch.FloatTensor)
# targetsTest = torch.from_numpy(y_test).type(torch.FloatTensor) # data type is long
# # Tensor Dataset
# test_data = torch.utils.data.TensorDataset(featuresTest,targetsTest)
# # Dataloader Generators
# #training_set = Dataset(X_sm, y_sm)
# test_loader = torch.utils.data.DataLoader(test_data, batch_size= 128,shuffle = False, num_workers= 0)
# return test_loader
class NeuralNet(nn.Module):
    """Two-hidden-layer feed-forward classifier.

    Architecture: input_dim -> hidden_dim -> round(hidden_dim/2) -> output_dim,
    where each hidden layer is Linear -> BatchNorm1d -> ELU -> Dropout and
    the output layer emits raw logits (no activation), as expected by
    CrossEntropyLoss.  Hidden-layer weights are re-initialised from U(0, 1)
    and their biases to zero; the output layer keeps PyTorch's default init.
    """

    def __init__(self, input_dim=40, hidden_dim=48, output_dim=2, dropout=0.5):
        super(NeuralNet, self).__init__()
        self.hidden_dim = hidden_dim
        # Second hidden layer is half as wide as the first
        mid_dim = round(hidden_dim / 2)
        self.fcl1 = nn.Linear(input_dim, hidden_dim)
        self.fcl2 = nn.Linear(hidden_dim, mid_dim)
        self.output = nn.Linear(mid_dim, output_dim)
        self.batchnorm1 = nn.BatchNorm1d(hidden_dim)
        self.batchnorm2 = nn.BatchNorm1d(mid_dim)
        self.dropout = nn.Dropout(p=dropout)
        # Re-initialise both hidden layers: U(0, 1) weights, zero biases
        # (applied in layer order, matching the construction order above)
        for layer in (self.fcl1, self.fcl2):
            nn.init.uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        # Hidden layer 1: Linear -> BatchNorm -> ELU -> Dropout
        h = self.dropout(F.elu(self.batchnorm1(self.fcl1(x))))
        # Hidden layer 2: the same stack at half width
        h = self.dropout(F.elu(self.batchnorm2(self.fcl2(h))))
        # Raw logits — the loss function applies the softmax
        return self.output(h)
# Instantiate the network with its default dimensions (40 -> 48 -> 24 -> 2)
model = NeuralNet()
# Tensor Datasets
# Wrap the training arrays as float tensors; targets are squeezed to 1-D.
# (The int64 -> FloatTensor cast is only for this visualisation cell.)
featuresTrain = torch.from_numpy(X_train.astype(np.float32)).type(torch.FloatTensor)
targetsTrain = torch.from_numpy(y_train.astype(np.int64).squeeze(1)).type(torch.FloatTensor)
# Tensor Datasets
train_data = torch.utils.data.TensorDataset(featuresTrain,targetsTrain)
# Dataloader Generators
train_loader = torch.utils.data.DataLoader(train_data, batch_size= 32,shuffle = True,num_workers= 0)
# Run one batch through the untrained model and draw its computation graph
batch = next(iter(train_loader))
y_dummy = model(batch[0]) #Dummy batch
make_dot(y_dummy, params=dict(list(model.named_parameters())))
# Print the freshly initialised parameters of a second instance
# (uniform hidden-layer weights, zero biases — see NeuralNet.__init__)
network = NeuralNet()
for name, param in network.named_parameters():
    if param.requires_grad:
        print(name, param.data)
fcl1.weight tensor([[0.9475, 0.1563, 0.4869, ..., 0.7679, 0.7637, 0.2424],
[0.4891, 0.8553, 0.5755, ..., 0.8138, 0.1891, 0.2430],
[0.2277, 0.1090, 0.0480, ..., 0.3774, 0.0670, 0.3274],
...,
[0.8521, 0.9357, 0.0014, ..., 0.0505, 0.9562, 0.3413],
[0.4872, 0.5581, 0.2291, ..., 0.0362, 0.8850, 0.5742],
[0.4028, 0.3190, 0.8566, ..., 0.5292, 0.7557, 0.6127]])
fcl1.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
fcl2.weight tensor([[0.5044, 0.4224, 0.4360, ..., 0.4902, 0.5276, 0.2539],
[0.3519, 0.5062, 0.0737, ..., 0.9800, 0.7191, 0.3877],
[0.6108, 0.1621, 0.7972, ..., 0.8015, 0.0025, 0.2232],
...,
[0.6278, 0.4434, 0.8417, ..., 0.7140, 0.1842, 0.1197],
[0.5449, 0.9856, 0.8773, ..., 0.3798, 0.5957, 0.6340],
[0.3335, 0.4141, 0.7281, ..., 0.5459, 0.6109, 0.8009]])
fcl2.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
output.weight tensor([[-0.0569, 0.1488, 0.0333, 0.1079, 0.1577, -0.0497, -0.1738, 0.1641,
0.1100, -0.0616, 0.1301, -0.0468, -0.1294, 0.0115, 0.0399, -0.0817,
-0.1880, 0.0978, -0.0501, -0.0016, 0.1887, 0.1700, -0.1928, 0.0095],
[-0.0508, -0.1641, -0.1597, 0.1092, 0.0144, -0.1146, 0.0069, -0.1952,
-0.1772, -0.1869, 0.1572, 0.0399, -0.0829, -0.0358, 0.1415, -0.1881,
-0.1402, -0.1212, 0.1431, -0.0164, 0.0414, 0.1928, 0.1246, 0.1919]])
output.bias tensor([-0.1618, 0.1801])
batchnorm1.weight tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
batchnorm1.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
batchnorm2.weight tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1.])
batchnorm2.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Now we feed our model to the skorch classifier and define a skorch object with initial parameters before optimization.
# Define class weights (If necessary)
#weights = [1.0, 3.0]
#class_weights = torch.FloatTensor(weights)
# Define checkpoints for training and validation loss
#monitor_loss = lambda nnet: all(nnet.history[-1, ('train_loss_best', 'valid_loss_best')])
# Define checkpoints for training and validation acc
#monitor_acc = lambda nnet: nnet.history[-1, ('valid_acc',)]
# Define epoch scoring functions: ROC AUC on the internal validation split,
# and accuracy computed on the training data, both logged once per epoch
rocauc = EpochScoring(scoring='roc_auc', lower_is_better=False)
tr_acc = EpochScoring(scoring='accuracy', name='train_acc', on_train=True, lower_is_better=False)
#nnet = NeuralNetBinaryClassifier(
# Skorch wrapper around NeuralNet: CrossEntropyLoss over the two output
# logits, Adam optimiser, early stopping after 10 epochs without
# validation-loss improvement
nnet = NeuralNetClassifier(
    module=NeuralNet,
    max_epochs=350,
    batch_size=32,
    #criterion=torch.nn.BCEWithLogitsLoss,
    #criterion=torch.nn.BCELoss,
    criterion=nn.CrossEntropyLoss,
    #criterion__reduction='mean',
    lr=0.0001,
    callbacks=[EarlyStopping(patience=10), tr_acc, rocauc], #, Checkpoint(monitor=monitor_loss)],
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
    #optimizer__momentum=0.9,
    optimizer=torch.optim.Adam,
    device=device
)
# deactivate skorch-internal train-valid split and verbose logging
#nnet.set_params(train_split=False, verbose=0)
# # CrossEntropyLoss
# nnet.fit(X_train_net.astype(np.float32), y_train_net.astype(np.int64).squeeze(1))
# y_proba = nnet.predict_proba(X_train_net.astype(np.float32))
# BCEWithLogitsLoss
# nnet.fit(X_train_net.astype(np.float32), y_train_net.astype(np.float32).squeeze(1))
# y_proba = nnet.predict_proba(X_train_net.astype(np.float32))
# Resample and plot the imbalanced dataset (SMOTE + random undersampling)
from collections import Counter
from sklearn.datasets import make_classification
from imblearn.under_sampling import OneSidedSelection
from matplotlib import pyplot
from numpy import where
from imblearn.under_sampling import EditedNearestNeighbours
# Summarize the class distribution before resampling
counter = Counter(y_train[:,0])
print(counter)
# Define the resampling strategy: SMOTE the minority class up to a 0.5
# minority/majority ratio, then randomly undersample the majority to parity
over = SMOTE(sampling_strategy=0.5, random_state=2 ,k_neighbors=7)
#under = EditedNearestNeighbours(sampling_strategy='majority', n_neighbors=7)
randunder = RandomUnderSampler(sampling_strategy='majority', random_state=2)
# Transform the dataset.  FIX: pass y as a 1-D array (ravel) — the previous
# column-vector y raised sklearn's DataConversionWarning on every fit_resample.
X_new, y_new = over.fit_resample(X_train, y_train.ravel())
#X_new, y_new = under.fit_resample(X_new, y_new)
X_new, y_new = randunder.fit_resample(X_new, y_new)
# Summarize the new (balanced) class distribution
counter = Counter(y_new)
print(counter)
# Scatter plot of examples by class label, using only the first two encoded
# feature columns as plot axes
for label, _ in counter.items():
    row_ix = where(y_new == label)[0]
    pyplot.scatter(X_new[row_ix, 0], X_new[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()
#Counter({0: 13649, 1: 3712})
Counter({0: 29238, 1: 3712})
/usr/local/lib/python3.7/dist-packages/sklearn/utils/validation.py:760: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().
Counter({0: 14619, 1: 14619})
def train_mlp_model_selection(X_train_model,y_train_model,n_iter=10):
    """Randomized hyper-parameter search for the skorch MLP classifier.

    Builds an imblearn pipeline (StandardScaler -> SMOTE oversampling ->
    random undersampling -> the globally defined skorch `nnet`) and runs a
    RandomizedSearchCV over learning rate, hidden width, dropout, weight
    decay, optimiser, max epochs and batch size with 5-fold stratified CV.
    Resampling steps run on the training folds only, never on validation
    folds, because imblearn's Pipeline applies samplers only during fit.

    Parameters
    ----------
    X_train_model, y_train_model : training features / targets (y may be a
        column vector; it is squeezed to 1-D before fitting).
    n_iter : number of random hyper-parameter candidates to evaluate.

    Returns
    -------
    (mlp_model_selection, nnet) : the fitted RandomizedSearchCV object and
        the (globally defined) skorch estimator.
    """
    scaler = preprocessing.StandardScaler()
    # Resampling: oversample minority to a 0.2 ratio, then undersample majority
    over = SMOTE(sampling_strategy=0.2, random_state=2 ,k_neighbors=7)
    rand_under = RandomUnderSampler(sampling_strategy='majority', random_state=2)
    # Create an imbalance-aware Pipeline with scaling + resampling + model
    nnet_pipeline = imbPipeline([('scaler',scaler),
                                 ('o', over),
                                 ('ru', rand_under),
                                 ('nnet', nnet)])
    # Search space for the randomized search (distributions are sampled,
    # lists are drawn uniformly)
    params_randcv ={
        'nnet__batch_size': [32],
        'nnet__module__hidden_dim':randint(20,80),
        'nnet__module__dropout': [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        'nnet__lr': loguniform(1e-5, 1e-1),
        'nnet__optimizer__weight_decay': loguniform(1e-6, 1e-1),
        'nnet__optimizer': [AdaBound, optim.Adam],
        'nnet__max_epochs': randint(50,350)
    }
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=2)
    # FIX: roc_auc_score and average_precision_score are ranking metrics and
    # must be computed from continuous scores, not hard 0/1 predictions.
    # needs_threshold=True makes the scorer call predict_proba /
    # decision_function; without it both metrics (and the refit choice
    # based on 'roc_auc_score') were computed from class labels.
    scorersMLP = {
        'precision_score': make_scorer(precision_score, zero_division=1),
        'recall_score': make_scorer(recall_score, zero_division=1),
        'accuracy_score': make_scorer(accuracy_score),
        'roc_auc_score': make_scorer(roc_auc_score, needs_threshold=True),
        'average_precision_score': make_scorer(average_precision_score, needs_threshold=True),
        'f1_score': make_scorer(f1_score)
    }
    # Preliminary fit on the full training set; this also leaves the global
    # `nnet` fitted, which is what the second return value exposes
    nnet_pipeline.fit(X_train_model.astype(np.float32), y_train_model.astype(np.int64).squeeze(1))
    mlp_rs = RandomizedSearchCV(nnet_pipeline, params_randcv, refit='roc_auc_score',
                                cv=skf, scoring=scorersMLP, return_train_score = True,
                                n_iter=n_iter, random_state=123, verbose=10, n_jobs=-1)
    # Time the search so the cost per candidate can be reported
    start = time.time()
    mlp_model_selection = mlp_rs.fit(X_train_model.astype(np.float32), y_train_model.astype(np.int64).squeeze(1))
    totaltime = time.time() - start
    print("RandomizedSearchCV (MLP) took %.2f seconds (%.2f hours) for %d selected candidates"
          " parameter settings." % ((totaltime), (totaltime/3600), n_iter))
    print("Best params: {}".format(mlp_model_selection.best_params_))
    print("Best scores: {}".format(mlp_model_selection.best_score_))
    return mlp_model_selection, nnet
# Launch the randomized search with 64 candidate configurations (long-running)
mlp_model_selection, nnet_selection = train_mlp_model_selection(X_train,y_train,n_iter = 64)
Re-initializing module.
Re-initializing optimizer.
epoch roc_auc train_acc train_loss valid_acc valid_loss dur
------- --------- ----------- ------------ ----------- ------------ ------
1 0.6620 0.5798 0.7070 0.6434 0.6636 0.8425
2 0.6688 0.6027 0.6937 0.6520 0.6558 0.7819
3 0.6748 0.6149 0.6886 0.6610 0.6526 0.8051
4 0.6796 0.6356 0.6721 0.6648 0.6487 0.7800
5 0.6845 0.6434 0.6666 0.6691 0.6448 0.7674
6 0.6889 0.6448 0.6672 0.6691 0.6406 0.8181
7 0.6919 0.6526 0.6573 0.6794 0.6399 0.7951
8 0.6960 0.6627 0.6547 0.6764 0.6355 0.8044
9 0.6981 0.6662 0.6492 0.6729 0.6344 0.8410
10 0.7010 0.6719 0.6428 0.6802 0.6314 0.8271
11 0.7040 0.6733 0.6428 0.6879 0.6287 0.7977
12 0.7060 0.6770 0.6397 0.6845 0.6278 0.7993
13 0.7083 0.6804 0.6403 0.6866 0.6259 0.8043
14 0.7101 0.6855 0.6332 0.6905 0.6242 0.7927
15 0.7124 0.6895 0.6306 0.6913 0.6231 0.7911
16 0.7143 0.6863 0.6300 0.6917 0.6214 0.8414
17 0.7165 0.6878 0.6294 0.6841 0.6224 0.7947
18 0.7180 0.6887 0.6272 0.6977 0.6185 0.7936
19 0.7200 0.6911 0.6221 0.7037 0.6164 0.7822
20 0.7221 0.6900 0.6264 0.6994 0.6163 0.7887
21 0.7233 0.6934 0.6225 0.6943 0.6172 0.7736
22 0.7259 0.6946 0.6200 0.6994 0.6146 0.7917
23 0.7269 0.6955 0.6195 0.6986 0.6144 0.7964
24 0.7287 0.6965 0.6198 0.7033 0.6126 0.8049
25 0.7302 0.7007 0.6205 0.7050 0.6117 0.8109
26 0.7323 0.6982 0.6153 0.7076 0.6097 0.8387
27 0.7333 0.7006 0.6139 0.7033 0.6104 0.7980
28 0.7349 0.6999 0.6170 0.7076 0.6083 0.7998
29 0.7362 0.7010 0.6118 0.7037 0.6093 0.8071
30 0.7378 0.7028 0.6119 0.7084 0.6066 0.7885
31 0.7387 0.7024 0.6119 0.7054 0.6072 0.8024
32 0.7402 0.6996 0.6116 0.7054 0.6064 0.8254
33 0.7415 0.7007 0.6100 0.7101 0.6046 0.8163
34 0.7426 0.7026 0.6106 0.7110 0.6040 0.8211
35 0.7439 0.7036 0.6082 0.7076 0.6048 0.8026
36 0.7451 0.7029 0.6079 0.7110 0.6026 0.7998
37 0.7459 0.7044 0.6086 0.7118 0.6022 0.8036
38 0.7472 0.7038 0.6091 0.7106 0.6021 0.8287
39 0.7481 0.7052 0.6092 0.7123 0.6010 0.8189
40 0.7489 0.7052 0.6036 0.7097 0.6012 0.7725
41 0.7499 0.7052 0.6066 0.7101 0.6009 0.7938
42 0.7507 0.7040 0.6048 0.7080 0.6017 0.7852
43 0.7521 0.7052 0.6080 0.7127 0.5992 0.8026
44 0.7526 0.7049 0.6049 0.7084 0.6004 0.7810
45 0.7533 0.7056 0.6066 0.7127 0.5984 0.7879
46 0.7542 0.7036 0.6057 0.7084 0.5996 0.8222
47 0.7550 0.7067 0.6041 0.7118 0.5981 0.8442
48 0.7556 0.7055 0.6051 0.7097 0.5984 0.8042
49 0.7565 0.7074 0.6038 0.7123 0.5964 0.8276
50 0.7575 0.7059 0.6033 0.7165 0.5949 0.8145
51 0.7583 0.7074 0.5992 0.7153 0.5946 0.8063
52 0.7587 0.7064 0.6027 0.7127 0.5959 0.8212
53 0.7594 0.7086 0.5997 0.7127 0.5950 0.8125
54 0.7602 0.7064 0.5990 0.7131 0.5945 0.7917
55 0.7608 0.7096 0.5966 0.7131 0.5941 0.7999
56 0.7615 0.7070 0.5986 0.7148 0.5929 0.8200
57 0.7622 0.7076 0.5980 0.7148 0.5926 0.8275
58 0.7627 0.7080 0.5992 0.7144 0.5930 0.8033
59 0.7635 0.7075 0.5993 0.7170 0.5907 0.8165
60 0.7639 0.7081 0.5970 0.7148 0.5911 0.8316
61 0.7644 0.7064 0.5984 0.7148 0.5923 0.8231
62 0.7649 0.7091 0.5979 0.7170 0.5900 0.7834
63 0.7655 0.7077 0.5976 0.7153 0.5899 0.7995
64 0.7660 0.7075 0.5971 0.7178 0.5883 0.7895
65 0.7666 0.7083 0.5962 0.7174 0.5886 0.8254
66 0.7670 0.7101 0.5949 0.7174 0.5879 0.8075
67 0.7673 0.7091 0.5963 0.7165 0.5887 0.7913
68 0.7678 0.7095 0.5937 0.7153 0.5891 0.8187
69 0.7683 0.7096 0.5943 0.7178 0.5860 0.8223
70 0.7686 0.7073 0.5952 0.7170 0.5875 0.8187
71 0.7690 0.7091 0.5920 0.7170 0.5868 0.8114
72 0.7692 0.7112 0.5919 0.7187 0.5850 0.7817
73 0.7696 0.7097 0.5909 0.7187 0.5842 0.7950
74 0.7698 0.7108 0.5912 0.7165 0.5855 0.7678
75 0.7701 0.7118 0.5912 0.7170 0.5846 0.8053
76 0.7705 0.7099 0.5917 0.7165 0.5843 0.7876
77 0.7708 0.7100 0.5909 0.7161 0.5843 0.8226
78 0.7711 0.7106 0.5890 0.7170 0.5836 0.8138
79 0.7714 0.7099 0.5906 0.7174 0.5823 0.8177
80 0.7716 0.7118 0.5907 0.7187 0.5817 0.7980
81 0.7720 0.7088 0.5911 0.7183 0.5819 0.7833
82 0.7723 0.7128 0.5875 0.7187 0.5808 0.7812
83 0.7725 0.7119 0.5868 0.7187 0.5802 0.7959
84 0.7728 0.7118 0.5871 0.7217 0.5783 0.7994
85 0.7730 0.7107 0.5888 0.7178 0.5800 0.8290
86 0.7732 0.7114 0.5891 0.7191 0.5789 0.8154
87 0.7735 0.7117 0.5862 0.7200 0.5774 0.7962
88 0.7737 0.7119 0.5862 0.7195 0.5775 0.7966
89 0.7739 0.7159 0.5850 0.7195 0.5765 0.7918
90 0.7742 0.7123 0.5834 0.7183 0.5771 0.7950
91 0.7743 0.7111 0.5833 0.7183 0.5769 0.8137
92 0.7746 0.7143 0.5849 0.7217 0.5742 0.7891
93 0.7747 0.7105 0.5837 0.7178 0.5763 0.8099
94 0.7750 0.7141 0.5832 0.7195 0.5743 0.7948
95 0.7750 0.7143 0.5827 0.7153 0.5769 0.8110
96 0.7753 0.7158 0.5822 0.7221 0.5726 0.8194
97 0.7754 0.7120 0.5829 0.7187 0.5735 0.7713
98 0.7756 0.7129 0.5827 0.7165 0.5733 0.7873
99 0.7758 0.7149 0.5832 0.7178 0.5725 0.8052
100 0.7759 0.7162 0.5804 0.7157 0.5732 0.8158
101 0.7761 0.7103 0.5817 0.7178 0.5717 0.8161
102 0.7762 0.7150 0.5813 0.7212 0.5707 0.7867
103 0.7763 0.7118 0.5813 0.7191 0.5708 0.8613
104 0.7765 0.7139 0.5791 0.7242 0.5688 0.8483
105 0.7767 0.7141 0.5799 0.7230 0.5684 0.8102
106 0.7768 0.7134 0.5777 0.7225 0.5680 0.7880
107 0.7770 0.7156 0.5794 0.7208 0.5690 0.8049
108 0.7771 0.7173 0.5753 0.7183 0.5698 0.8077
109 0.7772 0.7167 0.5741 0.7174 0.5709 0.8274
110 0.7774 0.7139 0.5781 0.7212 0.5674 0.7911
111 0.7775 0.7159 0.5779 0.7208 0.5671 0.8029
112 0.7777 0.7145 0.5735 0.7204 0.5673 0.8014
113 0.7778 0.7168 0.5752 0.7230 0.5652 0.8231
114 0.7778 0.7206 0.5760 0.7255 0.5647 0.7712
115 0.7779 0.7149 0.5767 0.7281 0.5642 0.8211
116 0.7781 0.7195 0.5765 0.7268 0.5644 0.8105
117 0.7782 0.7164 0.5732 0.7247 0.5635 0.7806
118 0.7783 0.7137 0.5758 0.7242 0.5640 0.7968
119 0.7783 0.7173 0.5758 0.7234 0.5645 0.7976
120 0.7785 0.7164 0.5765 0.7242 0.5633 0.8042
121 0.7786 0.7198 0.5738 0.7277 0.5620 0.7981
122 0.7787 0.7170 0.5731 0.7234 0.5634 0.7984
123 0.7788 0.7172 0.5736 0.7281 0.5622 0.7908
124 0.7789 0.7137 0.5735 0.7277 0.5616 0.7933
125 0.7790 0.7189 0.5741 0.7247 0.5626 0.8471
126 0.7791 0.7152 0.5744 0.7234 0.5642 0.8007
127 0.7791 0.7188 0.5702 0.7277 0.5607 0.7975
128 0.7792 0.7185 0.5719 0.7277 0.5605 0.8417
129 0.7793 0.7187 0.5692 0.7289 0.5600 0.8038
130 0.7793 0.7212 0.5702 0.7281 0.5600 0.7875
131 0.7794 0.7188 0.5743 0.7294 0.5604 0.7993
132 0.7794 0.7200 0.5729 0.7294 0.5597 0.8084
133 0.7796 0.7134 0.5756 0.7311 0.5595 0.7980
134 0.7796 0.7195 0.5724 0.7289 0.5596 0.8087
135 0.7797 0.7180 0.5700 0.7281 0.5588 0.8120
136 0.7798 0.7158 0.5735 0.7268 0.5610 0.8003
137 0.7798 0.7170 0.5697 0.7298 0.5588 0.8230
138 0.7798 0.7211 0.5690 0.7294 0.5582 0.8085
139 0.7799 0.7160 0.5710 0.7302 0.5581 0.8269
140 0.7800 0.7185 0.5712 0.7298 0.5577 0.8119
141 0.7800 0.7207 0.5734 0.7272 0.5589 0.8109
142 0.7800 0.7153 0.5720 0.7289 0.5584 0.8582
143 0.7801 0.7152 0.5744 0.7268 0.5589 0.7973
144 0.7802 0.7200 0.5686 0.7302 0.5572 0.7858
145 0.7803 0.7168 0.5718 0.7311 0.5579 0.8265
146 0.7802 0.7170 0.5700 0.7289 0.5572 0.8068
147 0.7804 0.7164 0.5730 0.7251 0.5592 0.8072
148 0.7805 0.7177 0.5723 0.7332 0.5566 0.7974
149 0.7804 0.7167 0.5705 0.7298 0.5577 0.7912
150 0.7805 0.7203 0.5706 0.7289 0.5573 0.7959
151 0.7805 0.7189 0.5695 0.7289 0.5565 0.8129
152 0.7806 0.7164 0.5688 0.7285 0.5567 0.8256
153 0.7805 0.7244 0.5666 0.7302 0.5564 0.8084
154 0.7807 0.7162 0.5695 0.7268 0.5579 0.8344
155 0.7807 0.7184 0.5708 0.7307 0.5557 0.8215
156 0.7807 0.7159 0.5693 0.7319 0.5557 0.8200
157 0.7807 0.7172 0.5688 0.7247 0.5582 0.9542
158 0.7808 0.7190 0.5686 0.7307 0.5554 0.7851
159 0.7808 0.7194 0.5669 0.7307 0.5554 0.9409
160 0.7808 0.7198 0.5670 0.7311 0.5550 0.8661
161 0.7808 0.7183 0.5660 0.7251 0.5563 0.8260
162 0.7809 0.7204 0.5653 0.7315 0.5546 0.7925
163 0.7809 0.7207 0.5691 0.7311 0.5547 0.7806
164 0.7809 0.7137 0.5709 0.7264 0.5559 0.8239
165 0.7810 0.7172 0.5695 0.7307 0.5542 0.7687
166 0.7809 0.7209 0.5629 0.7242 0.5559 0.8090
167 0.7810 0.7167 0.5717 0.7302 0.5554 0.8525
168 0.7811 0.7170 0.5688 0.7319 0.5543 0.8005
169 0.7810 0.7199 0.5645 0.7307 0.5540 0.8126
170 0.7810 0.7167 0.5672 0.7311 0.5538 0.8058
171 0.7811 0.7199 0.5681 0.7307 0.5540 0.8015
172 0.7812 0.7234 0.5620 0.7311 0.5538 0.8017
173 0.7811 0.7163 0.5685 0.7307 0.5541 0.7994
174 0.7811 0.7177 0.5669 0.7307 0.5542 0.7822
175 0.7812 0.7204 0.5682 0.7311 0.5545 0.8575
176 0.7813 0.7214 0.5650 0.7302 0.5533 0.8074
177 0.7812 0.7159 0.5671 0.7234 0.5551 0.8406
178 0.7812 0.7241 0.5656 0.7315 0.5537 0.7753
179 0.7813 0.7167 0.5700 0.7298 0.5532 0.7998
180 0.7813 0.7165 0.5649 0.7289 0.5546 0.7839
181 0.7813 0.7224 0.5671 0.7319 0.5530 0.7770
182 0.7813 0.7145 0.5687 0.7315 0.5536 0.7898
183 0.7813 0.7197 0.5665 0.7324 0.5536 0.7845
184 0.7814 0.7147 0.5648 0.7315 0.5536 0.7976
185 0.7814 0.7209 0.5682 0.7311 0.5541 0.7842
186 0.7814 0.7185 0.5639 0.7324 0.5529 0.7835
187 0.7816 0.7213 0.5628 0.7307 0.5526 0.8293
188 0.7814 0.7207 0.5668 0.7302 0.5530 0.7931
189 0.7815 0.7189 0.5652 0.7200 0.5579 0.8012
190 0.7814 0.7208 0.5662 0.7298 0.5532 0.7668
191 0.7814 0.7192 0.5651 0.7315 0.5534 0.7749
192 0.7815 0.7179 0.5670 0.7289 0.5541 0.8419
193 0.7815 0.7173 0.5678 0.7315 0.5534 0.8092
194 0.7815 0.7168 0.5668 0.7319 0.5523 0.7892
195 0.7815 0.7176 0.5649 0.7319 0.5524 0.8240
196 0.7815 0.7164 0.5674 0.7302 0.5526 0.7965
197 0.7816 0.7207 0.5612 0.7242 0.5546 0.8066
198 0.7816 0.7162 0.5648 0.7294 0.5522 0.7969
199 0.7816 0.7181 0.5671 0.7324 0.5533 0.8022
200 0.7816 0.7227 0.5661 0.7315 0.5531 0.8079
201 0.7817 0.7185 0.5675 0.7324 0.5521 0.8690
202 0.7817 0.7210 0.5684 0.7298 0.5525 0.7974
203 0.7818 0.7182 0.5625 0.7230 0.5547 0.8165
204 0.7817 0.7162 0.5661 0.7247 0.5541 0.7929
205 0.7818 0.7203 0.5663 0.7324 0.5532 0.7970
206 0.7818 0.7203 0.5671 0.7328 0.5518 0.8442
207 0.7818 0.7207 0.5629 0.7315 0.5517 0.8474
208 0.7819 0.7187 0.5653 0.7319 0.5518 0.7731
209 0.7819 0.7162 0.5639 0.7332 0.5515 0.8190
210 0.7818 0.7203 0.5666 0.7328 0.5525 0.8100
211 0.7819 0.7161 0.5652 0.7298 0.5521 0.7896
212 0.7819 0.7194 0.5645 0.7319 0.5517 0.8099
213 0.7819 0.7251 0.5622 0.7298 0.5518 0.7966
214 0.7820 0.7229 0.5648 0.7324 0.5524 0.8033
215 0.7821 0.7204 0.5655 0.7289 0.5526 0.7878
216 0.7819 0.7160 0.5653 0.7298 0.5527 0.7978
217 0.7819 0.7177 0.5675 0.7307 0.5519 0.8145
218 0.7820 0.7215 0.5667 0.7298 0.5524 0.8193
Stopping since valid_loss has not improved in the last 10 epochs.
Fitting 5 folds for each of 64 candidates, totalling 320 fits
epoch roc_auc train_acc train_loss valid_acc valid_loss dur
------- --------- ----------- ------------ ----------- ------------ ------
1 0.7782 0.6942 0.6099 0.7046 0.5743 0.9785
2 0.7802 0.7129 0.5770 0.7234 0.5546 0.9802
3 0.7820 0.7189 0.5674 0.7332 0.5490 0.9934
4 0.7817 0.7197 0.5634 0.7285 0.5497 0.9948
5 0.7829 0.7197 0.5647 0.7315 0.5570 1.0047
6 0.7815 0.7239 0.5621 0.7285 0.5500 1.0186
7 0.7825 0.7266 0.5623 0.7264 0.5500 0.9914
8 0.7837 0.7214 0.5625 0.7217 0.5517 0.9774
9 0.7817 0.7235 0.5611 0.7272 0.5497 1.0104
10 0.7829 0.7194 0.5618 0.7277 0.5474 0.9629
11 0.7831 0.7230 0.5588 0.7401 0.5459 1.0385
12 0.7834 0.7272 0.5575 0.7234 0.5513 0.9736
13 0.7821 0.7211 0.5582 0.7328 0.5463 0.9824
14 0.7820 0.7199 0.5568 0.7349 0.5470 0.9754
15 0.7817 0.7255 0.5565 0.7324 0.5475 0.9922
16 0.7810 0.7210 0.5561 0.7268 0.5489 0.9806
17 0.7833 0.7243 0.5543 0.7349 0.5461 0.9999
18 0.7816 0.7208 0.5586 0.7319 0.5468 0.9600
19 0.7842 0.7265 0.5538 0.7349 0.5446 0.9940
20 0.7828 0.7294 0.5542 0.7345 0.5456 0.9947
21 0.7837 0.7267 0.5559 0.7362 0.5460 0.9993
22 0.7843 0.7243 0.5549 0.7366 0.5478 1.0107
23 0.7846 0.7260 0.5517 0.7354 0.5447 1.0162
24 0.7858 0.7262 0.5528 0.7358 0.5451 0.9822
25 0.7848 0.7262 0.5537 0.7332 0.5441 0.9710
26 0.7852 0.7259 0.5531 0.7319 0.5463 1.0127
27 0.7852 0.7259 0.5541 0.7371 0.5440 0.9635
28 0.7843 0.7305 0.5512 0.7336 0.5501 1.0555
29 0.7849 0.7290 0.5520 0.7328 0.5455 1.0591
30 0.7838 0.7283 0.5527 0.7341 0.5486 0.9831
31 0.7852 0.7275 0.5526 0.7345 0.5445 0.9915
32 0.7834 0.7278 0.5528 0.7328 0.5474 0.9945
33 0.7825 0.7283 0.5517 0.7366 0.5474 0.9616
34 0.7859 0.7306 0.5518 0.7375 0.5454 1.0094
35 0.7865 0.7290 0.5506 0.7341 0.5445 0.9917
36 0.7849 0.7263 0.5521 0.7324 0.5466 1.0032
Stopping since valid_loss has not improved in the last 10 epochs.
RandomizedSearchCV (MLP) took 8553.88 seconds (2.38 hours) for 64 selected candidates parameter settings.
Best params: {'nnet__batch_size': 32, 'nnet__lr': 0.02072290032105245, 'nnet__max_epochs': 53, 'nnet__module__dropout': 0.1, 'nnet__module__hidden_dim': 44, 'nnet__optimizer': <class '__main__.AdaBound'>, 'nnet__optimizer__weight_decay': 0.0008001291628848295}
Best scores: 0.7292344419008214
mlp_model_selection_2 = mlp_model_selection
Save model selection estimator to the disk
dump(mlp_model_selection, 'mlp_model_selection.joblib')
['mlp_model_selection.joblib']
# Save best model trained on entire training set
nnet.save_params(f_params='mlp_model_selection.pkl', f_optimizer='mlp_model_selection_opt.pkl',
f_history='mlp_model_selection_history.json')
#y_proba = mlp_best_model.predict_proba(X_train.astype(np.float32))
# Display the full cross-validation results of the hyper-parameter search
# (notebook cell output). `mlp_best_model` is presumably the fitted
# RandomizedSearchCV object — verify against the cell that defines it.
mlp_best_model.cv_results_
{'mean_fit_time': array([ 40.20006232, 126.9082253 , 90.4696847 , 92.59804678,
33.03423095, 37.45918126, 68.84802742, 427.96842856,
32.97734056, 211.92919326, 364.7022759 , 29.1058476 ,
34.82939792, 107.77825089, 223.86518731, 108.53101988,
29.8953897 , 71.48975186, 270.76789136, 86.97907391,
227.43093672, 30.61117463, 143.24132252, 31.72202902,
123.62872438, 78.9103498 , 165.70403337, 65.76558399,
107.79964752, 31.60006776, 171.44596238, 31.26067028,
59.70401416, 27.13257914, 170.68083549, 89.7373909 ,
81.08183303, 124.37984114, 110.59441633, 41.16142364,
45.66267314, 41.51212716, 60.22580981, 118.28041553,
24.73878994, 20.16117187, 57.71147466, 106.2187593 ,
96.79929786, 205.8484334 , 91.43404021, 30.22234559,
135.28341813, 142.78744173, 213.9453073 , 23.21465764,
87.2141459 , 34.36895413, 182.14707999, 270.19072375,
111.42497444, 108.69445705, 20.71821685, 31.4815814 ]),
'mean_score_time': array([0.37702832, 0.3792841 , 0.38211341, 0.36646185, 0.36619453,
0.38486366, 0.38207717, 0.34967918, 0.35587096, 0.36126442,
0.35967903, 0.36137547, 0.35899239, 0.3705955 , 0.37025037,
0.35680375, 0.36620455, 0.36679206, 0.34916029, 0.37593231,
0.3608983 , 0.36603222, 0.37645226, 0.35831976, 0.35929341,
0.35714912, 0.35917134, 0.36623869, 0.36350179, 0.36593218,
0.36694345, 0.36149192, 0.36388354, 0.36794691, 0.3775353 ,
0.35762658, 0.35410748, 0.37023277, 0.36843677, 0.36949105,
0.36904302, 0.37452264, 0.3745441 , 0.36066084, 0.36780829,
0.3594285 , 0.3581181 , 0.36536465, 0.37527213, 0.34771786,
0.35279579, 0.36390042, 0.36936378, 0.37327199, 0.36017284,
0.36078386, 0.36203046, 0.35731397, 0.34926281, 0.34677839,
0.35522213, 0.34416313, 0.36461272, 0.29571152]),
'mean_test_accuracy_score': array([0.73902883, 0.80045524, 0.69189681, 0.8045827 , 0.75104704,
0.76549317, 0.77681335, 0.7123824 , 0.68819423, 0.77787557,
0.73338392, 0.79213961, 0.79638847, 0.77569044, 0.76567527,
0.77162367, 0.62367223, 0.78342944, 0.71672231, 0.79620637,
0.75584219, 0.79811836, 0.71978756, 0.57326252, 0.80376328,
0.78952959, 0.79062215, 0.7753566 , 0.76831563, 0.59766313,
0.78555387, 0.75502276, 0.79578149, 0.73195751, 0.79101669,
0.71320182, 0.64254932, 0.73147193, 0.75186646, 0.75729894,
0.7814264 , 0.78552352, 0.77308042, 0.77720789, 0.69951442,
0.344522 , 0.76877086, 0.78409712, 0.77013657, 0.72880121,
0.80555387, 0.44518968, 0.78424886, 0.73754173, 0.77426404,
0.77672231, 0.7722003 , 0.78312595, 0.68130501, 0.76588771,
0.79280728, 0.77432473, 0.66974203, 0.77383915]),
'mean_test_average_precision_score': array([0.21702152, 0.2366917 , 0.19320427, 0.240862 , 0.2161196 ,
0.22147106, 0.22640136, 0.20248586, 0.16926606, 0.22897794,
0.20962364, 0.24064611, 0.2353301 , 0.22760791, 0.222178 ,
0.22599919, 0.18220766, 0.22996659, 0.20447597, 0.23543049,
0.21840376, 0.23786862, 0.20499473, 0.11313966, 0.23981652,
0.23367916, 0.23297341, 0.22715167, 0.22413538, 0.18764554,
0.23216406, 0.21282656, 0.23483734, 0.20356163, 0.23334036,
0.20150453, 0.17235584, 0.20905599, 0.2181583 , 0.2194903 ,
0.23121002, 0.23091181, 0.22455872, 0.22804782, 0.20522223,
0.1103386 , 0.22284012, 0.22932706, 0.22477555, 0.20788918,
0.23917651, 0.13895419, 0.23167735, 0.21199358, 0.22542644,
0.22459329, 0.22440431, 0.22876741, 0.18810705, 0.22301715,
0.23462254, 0.22742475, 0.19190078, 0.22647585]),
'mean_test_f1_score': array([0.38089663, 0.41489028, 0.3426815 , 0.4206248 , 0.38095209,
0.39007331, 0.39806419, 0.35755906, 0.2705537 , 0.40139659,
0.36985147, 0.41740053, 0.41239532, 0.39934085, 0.39093407,
0.39679297, 0.32380637, 0.40370802, 0.36070644, 0.41256282,
0.38468166, 0.41454219, 0.36181526, 0.09531169, 0.41912579,
0.40930729, 0.40848531, 0.39874121, 0.39383154, 0.32790603,
0.40677515, 0.37733662, 0.41164931, 0.36173165, 0.40917127,
0.35639435, 0.3082296 , 0.36866637, 0.38380462, 0.38619478,
0.40462062, 0.40508636, 0.39527942, 0.40023008, 0.3607349 ,
0.13573095, 0.39242299, 0.40304793, 0.39486713, 0.3669244 ,
0.4187314 , 0.24449314, 0.40582914, 0.37350542, 0.3964818 ,
0.39314923, 0.39483758, 0.4021944 , 0.33450434, 0.39200139,
0.41088725, 0.39884334, 0.33364538, 0.39769292]),
'mean_test_precision_score': array([0.26836444, 0.31119968, 0.22574702, 0.31626833, 0.26501585,
0.2759156 , 0.28813559, 0.23892307, 0.26243548, 0.28851607,
0.2521993 , 0.3152922 , 0.30781937, 0.28646396, 0.2770062 ,
0.28338999, 0.22555101, 0.29262527, 0.24195493, 0.30776363,
0.26924639, 0.32751002, 0.24376388, 0.47805106, 0.31448199,
0.29992736, 0.3000414 , 0.28600227, 0.27960661, 0.27074105,
0.2957792 , 0.26579543, 0.30532174, 0.25374922, 0.3009486 ,
0.23855021, 0.19711973, 0.25136138, 0.26723767, 0.27473886,
0.29458332, 0.29588918, 0.28275339, 0.28775683, 0.26111761,
0.27658159, 0.27925523, 0.29381927, 0.28221074, 0.24905715,
0.31649457, 0.15026233, 0.29436665, 0.25687516, 0.28402792,
0.30505946, 0.28178163, 0.29281105, 0.21888647, 0.27752482,
0.30295935, 0.28535363, 0.26326984, 0.28559325]),
'mean_test_recall_score': array([0.68263142, 0.62714427, 0.71281285, 0.62877241, 0.67941107,
0.66567496, 0.65112551, 0.71039096, 0.50311188, 0.66028376,
0.69396125, 0.64035545, 0.63012737, 0.66082248, 0.66620679,
0.6653949 , 0.67266926, 0.65086141, 0.70877407, 0.63389152,
0.6764443 , 0.60452671, 0.70311624, 0.40916442, 0.6290405 ,
0.64654511, 0.64115899, 0.65974069, 0.6670234 , 0.67840401,
0.6524685 , 0.65677573, 0.63254236, 0.66356216, 0.64196109,
0.7047324 , 0.70689345, 0.69315589, 0.68452584, 0.66753636,
0.65354522, 0.64600313, 0.65785825, 0.65839624, 0.68290604,
0.63418573, 0.66243865, 0.64600821, 0.66189267, 0.69718958,
0.62175852, 0.77745934, 0.65355139, 0.6893736 , 0.6581394 ,
0.61931704, 0.66028957, 0.6460024 , 0.71012614, 0.66890293,
0.6403442 , 0.66324836, 0.6718483 , 0.66136665]),
'mean_test_roc_auc_score': array([0.71440678, 0.72480007, 0.70102634, 0.72783809, 0.71977566,
0.72192005, 0.72194729, 0.7115125 , 0.60738691, 0.7265436 ,
0.71617484, 0.72588251, 0.72381322, 0.72554737, 0.7222538 ,
0.72525078, 0.64506202, 0.72556063, 0.71325205, 0.72535266,
0.72118194, 0.71361265, 0.71250928, 0.50165813, 0.72749303,
0.72711304, 0.72537794, 0.72488659, 0.72409934, 0.63291953,
0.7274582 , 0.71213417, 0.7245241 , 0.70210602, 0.72594956,
0.70950392, 0.67063664, 0.71474659, 0.72246941, 0.71811207,
0.72560239, 0.72461875, 0.72278253, 0.72534314, 0.69226291,
0.47096665, 0.72235367, 0.72381785, 0.72288393, 0.71500126,
0.72532243, 0.5902343 , 0.72719623, 0.71651404, 0.72357361,
0.70800866, 0.72334892, 0.72326746, 0.6938862 , 0.72355076,
0.72625257, 0.72583696, 0.67066016, 0.72474251]),
'mean_train_accuracy_score': array([0.73854325, 0.80305008, 0.69279211, 0.80685888, 0.75144917,
0.76483308, 0.77681335, 0.71253414, 0.69000759, 0.77829287,
0.73340668, 0.79210167, 0.79707132, 0.77588771, 0.76663126,
0.77160091, 0.62355842, 0.78375569, 0.71610015, 0.79649469,
0.75501517, 0.79723065, 0.71990895, 0.57288316, 0.80658574,
0.79049317, 0.79308042, 0.77547041, 0.76759484, 0.59747344,
0.78634294, 0.7598786 , 0.79737481, 0.73236722, 0.79227618,
0.71300455, 0.64128983, 0.73208649, 0.75118361, 0.7583915 ,
0.78225341, 0.78731411, 0.77360395, 0.7783915 , 0.69977997,
0.34347496, 0.76949924, 0.78638088, 0.77144917, 0.72924886,
0.8077997 , 0.44945372, 0.78570561, 0.73864188, 0.77366464,
0.77691199, 0.77193475, 0.7825569 , 0.68141882, 0.76687405,
0.79550835, 0.7756525 , 0.67088012, 0.7729742 ]),
'mean_train_average_precision_score': array([0.21497618, 0.24321269, 0.19440066, 0.2472908 , 0.21730556,
0.22281207, 0.22838262, 0.2023505 , 0.17784609, 0.22946377,
0.2103473 , 0.23735114, 0.23703206, 0.22796273, 0.2234984 ,
0.22732672, 0.18364811, 0.23114049, 0.20394416, 0.23960127,
0.21944891, 0.23557897, 0.20454476, 0.11292716, 0.24767791,
0.23623976, 0.23832404, 0.22829215, 0.22435397, 0.1840264 ,
0.23346046, 0.21915277, 0.23767992, 0.20563884, 0.23844414,
0.20148702, 0.17219929, 0.20934526, 0.21869018, 0.22046517,
0.23146855, 0.2339176 , 0.226229 , 0.22983424, 0.20525123,
0.1102228 , 0.22522387, 0.23442095, 0.22698475, 0.20878832,
0.24481112, 0.14147853, 0.23409484, 0.21249605, 0.22750731,
0.22640685, 0.22570063, 0.22963031, 0.18831966, 0.22427315,
0.23957712, 0.22956059, 0.19284159, 0.2265541 ]),
'mean_train_f1_score': array([0.37845354, 0.42338902, 0.34450655, 0.42899936, 0.38267847,
0.39179901, 0.40075213, 0.35758086, 0.2885044 , 0.40240096,
0.37098295, 0.413958 , 0.4148567 , 0.40011972, 0.39299578,
0.39855934, 0.32541765, 0.40541858, 0.36015836, 0.41772839,
0.38595831, 0.4117794 , 0.36151845, 0.09397992, 0.42939317,
0.41280283, 0.41578395, 0.40045899, 0.39422086, 0.32284285,
0.40870868, 0.3864108 , 0.41580994, 0.36462761, 0.4157295 ,
0.35653435, 0.30788214, 0.36954472, 0.3843684 , 0.38778699,
0.40546222, 0.40939221, 0.39758742, 0.40285946, 0.36100277,
0.13721967, 0.39559686, 0.40982701, 0.39813756, 0.36834762,
0.42616329, 0.25097907, 0.40935858, 0.37459906, 0.3991506 ,
0.39614332, 0.39664079, 0.40331218, 0.3349002 , 0.39399146,
0.4177504 , 0.40203161, 0.33603641, 0.39782719]),
'mean_train_precision_score': array([0.26566897, 0.31766122, 0.22709941, 0.32215989, 0.26610318,
0.27650136, 0.28942847, 0.2389592 , 0.27767528, 0.28914285,
0.25287946, 0.31240387, 0.30960471, 0.2868274 , 0.27827972,
0.28439091, 0.22711518, 0.29380402, 0.24137695, 0.3116205 ,
0.26956733, 0.32467918, 0.24339588, 0.4741081 , 0.32199885,
0.30230238, 0.30543904, 0.28662093, 0.27944867, 0.26553324,
0.29716579, 0.27285674, 0.30822034, 0.25638833, 0.30520875,
0.23851908, 0.19677478, 0.25192138, 0.26746434, 0.27504355,
0.29493366, 0.29921905, 0.28416245, 0.28945699, 0.2601101 ,
0.27747202, 0.28138984, 0.29904537, 0.28399084, 0.24993485,
0.32240709, 0.15475019, 0.29703563, 0.25741797, 0.28546989,
0.30422332, 0.28289755, 0.29319448, 0.21919705, 0.27882582,
0.30798638, 0.28750357, 0.26459674, 0.28589738]),
'mean_train_recall_score': array([0.68292323, 0.6388082 , 0.71511305, 0.64318347, 0.68271885,
0.67221107, 0.65847288, 0.71012859, 0.52279324, 0.66224377,
0.69712976, 0.63826617, 0.63536819, 0.6627156 , 0.66985581,
0.6691832 , 0.66945396, 0.65416137, 0.70925319, 0.64250934,
0.681507 , 0.60122475, 0.70332684, 0.40835017, 0.64527196,
0.65254655, 0.65261242, 0.66520848, 0.67039219, 0.67233001,
0.65497111, 0.66790362, 0.63981567, 0.66903525, 0.65402851,
0.70575144, 0.70777096, 0.6955798 , 0.68628866, 0.6718123 ,
0.65551026, 0.65194063, 0.66251399, 0.66318719, 0.68642536,
0.63846413, 0.66783479, 0.65537458, 0.6690484 , 0.70043071,
0.6315337 , 0.78529367, 0.65887629, 0.6918106 , 0.66574455,
0.62763122, 0.66459976, 0.64971877, 0.71080115, 0.67214496,
0.65039226, 0.6690462 , 0.67167941, 0.66123238]),
'mean_train_roc_auc_score': array([0.71426461, 0.73135527, 0.70253563, 0.73541101, 0.72144698,
0.72440158, 0.7251553 , 0.71148399, 0.61701865, 0.72763498,
0.7175709 , 0.72494912, 0.72648409, 0.72648574, 0.72438694,
0.72689377, 0.64359236, 0.72718496, 0.71311126, 0.72927673,
0.72292751, 0.71166961, 0.71267052, 0.50105409, 0.73616889,
0.73027669, 0.73176312, 0.72733891, 0.7251637 , 0.63014701,
0.72899651, 0.71972991, 0.72859684, 0.70472022, 0.73192833,
0.70983842, 0.67031014, 0.71615024, 0.7228558 , 0.72059855,
0.7269275 , 0.72822091, 0.72511094, 0.72810247, 0.69395073,
0.47224391, 0.72512068, 0.72919393, 0.7267494 , 0.71666912,
0.730856 , 0.59605423, 0.7303419 , 0.71819913, 0.72655504,
0.71174839, 0.72508059, 0.72457046, 0.69424456, 0.72552292,
0.73216235, 0.72911661, 0.67122927, 0.72419639]),
'param_nnet__batch_size': masked_array(data=[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
32, 32, 32, 32, 32, 32, 32, 32],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False],
fill_value='?',
dtype=object),
'param_nnet__lr': masked_array(data=[0.030363087280807683, 0.07945821591472181,
0.0010117839439957464, 0.000977735258317486,
0.004547334157724477, 0.011353180782126507,
0.0001384573943799594, 1.3800975192197077e-05,
0.29305275974372563, 3.785227402204702e-05,
3.504650876019769e-05, 0.057242022355422285,
0.0014692347652443534, 0.00021828259540454984,
0.00022520255999117748, 0.2854601431192754,
0.2670755499678923, 0.0003543378504703912,
3.489084427199309e-05, 0.42234022608372174,
0.0009793285366373917, 0.06497391993423347,
0.00016005877496078506, 0.26547587190922456,
0.0008398254984484967, 0.13981802799450543,
0.00035155638470374623, 0.00029003276569871,
0.00019395917168733251, 0.23205722838016982,
0.1352152382314353, 0.005480470736619931,
0.006998720616346746, 0.034953202834835,
0.00027082146536393204, 0.021103876854752467,
1.738397164548146e-05, 0.11863655156430175,
0.0004932731869274647, 0.012082825645612937,
0.0013229222029814336, 0.0021874486674363864,
0.0013169750116779452, 0.14594279408871638,
0.03239827154479142, 0.9559422064058228,
0.0029175187967904756, 0.000335850145776681,
0.10789934402708082, 3.5138434536524006e-05,
0.004457856212421489, 0.12609857901253715,
0.00031811455812097975, 0.013537759936892031,
3.056263488563348e-05, 0.09876741017260972,
0.0003354039958849039, 0.026186495085634885,
2.173918085804243e-05, 3.3668931226191865e-05,
0.001013747100637488, 0.005555234816456331,
0.15502639223980508, 0.24800916805387582],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False],
fill_value='?',
dtype=object),
'param_nnet__max_epochs': masked_array(data=[148, 275, 274, 118, 134, 230, 126, 328, 208, 320, 236,
318, 106, 153, 272, 248, 166, 159, 306, 110, 185, 169,
108, 126, 301, 53, 180, 224, 142, 237, 149, 150, 211,
278, 343, 60, 57, 187, 139, 158, 167, 177, 82, 108,
103, 121, 69, 75, 194, 142, 178, 64, 121, 178, 308,
288, 273, 161, 127, 185, 139, 334, 70, 331],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False],
fill_value='?',
dtype=object),
'param_nnet__module__dropout': masked_array(data=[0.6, 0.0, 0.9, 0.1, 0.8, 0.6, 0.8, 0.2, 0.4, 0.4, 0.5,
0.2, 0.1, 0.9, 0.6, 0.6, 0.3, 0.6, 0.3, 0.1, 0.6, 0.0,
0.8, 0.9, 0.1, 0.1, 0.2, 0.6, 0.8, 0.7, 0.5, 0.4, 0.5,
0.3, 0.1, 0.8, 0.6, 0.8, 0.2, 0.8, 0.0, 0.4, 0.7, 0.6,
0.9, 0.7, 0.5, 0.1, 0.7, 0.6, 0.0, 0.9, 0.1, 0.8, 0.2,
0.0, 0.9, 0.5, 0.8, 0.5, 0.3, 0.1, 0.6, 0.2],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False],
fill_value='?',
dtype=object),
'param_nnet__module__hidden_dim': masked_array(data=[37, 69, 39, 75, 68, 53, 23, 56, 26, 59, 75, 37, 42, 54,
29, 47, 21, 66, 49, 72, 44, 52, 34, 62, 75, 44, 66, 75,
31, 45, 79, 25, 29, 68, 43, 73, 55, 61, 56, 73, 68, 63,
21, 21, 30, 43, 21, 73, 36, 54, 22, 37, 34, 33, 65, 38,
47, 50, 44, 39, 25, 41, 76, 51],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False],
fill_value='?',
dtype=object),
'param_nnet__optimizer': masked_array(data=[<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class '__main__.AdaBound'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>,
<class 'torch.optim.adam.Adam'>,
<class '__main__.AdaBound'>],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False],
fill_value='?',
dtype=object),
'param_nnet__optimizer__weight_decay': masked_array(data=[0.000285492476017613, 5.009032080513106e-06,
1.9878768816800636e-06, 7.538045954647015e-06,
0.004190604010683869, 6.440851380101242e-05,
0.007674007060713738, 0.00013544299070134376,
0.0013175428816538889, 0.038325988832442925,
0.00015286332925360343, 2.5912819605161014e-05,
0.022225239477743557, 0.015320184123405632,
3.75692244099265e-05, 3.00997288321983e-06,
0.00024279391489541403, 0.006594524291805665,
8.805476667129781e-05, 3.340722230503404e-05,
6.01969279860118e-05, 9.871461672235258e-05,
1.5894157764750446e-05, 0.018718129070616634,
7.168184021956474e-06, 0.0008001291628848295,
1.0918910494880818e-05, 0.0828236612161965,
0.0014584592294422608, 5.585661948867728e-05,
0.00017913582681945432, 0.04344313929394729,
1.749424338334307e-06, 0.00803507046980881,
3.6054614636344956e-06, 8.044093739827801e-06,
0.0008191349178468166, 6.097841646576797e-06,
1.0699379977403527e-05, 9.928726565697848e-05,
0.011434620128764766, 0.0006684309639782783,
0.0028424389283779566, 9.065604548710317e-05,
0.0018485208640276897, 3.309942830684619e-05,
8.868208076731682e-06, 1.8581252263907772e-06,
0.0004765745875026169, 0.0026299408570035557,
0.0002867828637544294, 1.6020301413263788e-06,
0.0004983381237830977, 7.51497668283994e-05,
5.118622104539927e-05, 0.001051237436877294,
0.0036514228635741385, 0.00032385865032664383,
1.1750587330969347e-05, 0.005286281863720477,
0.00023603337607522662, 9.285786642459022e-05,
0.00024266164480618433, 0.0035547434004666316],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False],
fill_value='?',
dtype=object),
'params': [{'nnet__batch_size': 32,
'nnet__lr': 0.030363087280807683,
'nnet__max_epochs': 148,
'nnet__module__dropout': 0.6,
'nnet__module__hidden_dim': 37,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.000285492476017613},
{'nnet__batch_size': 32,
'nnet__lr': 0.07945821591472181,
'nnet__max_epochs': 275,
'nnet__module__dropout': 0.0,
'nnet__module__hidden_dim': 69,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 5.009032080513106e-06},
{'nnet__batch_size': 32,
'nnet__lr': 0.0010117839439957464,
'nnet__max_epochs': 274,
'nnet__module__dropout': 0.9,
'nnet__module__hidden_dim': 39,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 1.9878768816800636e-06},
{'nnet__batch_size': 32,
'nnet__lr': 0.000977735258317486,
'nnet__max_epochs': 118,
'nnet__module__dropout': 0.1,
'nnet__module__hidden_dim': 75,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 7.538045954647015e-06},
{'nnet__batch_size': 32,
'nnet__lr': 0.004547334157724477,
'nnet__max_epochs': 134,
'nnet__module__dropout': 0.8,
'nnet__module__hidden_dim': 68,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.004190604010683869},
{'nnet__batch_size': 32,
'nnet__lr': 0.011353180782126507,
'nnet__max_epochs': 230,
'nnet__module__dropout': 0.6,
'nnet__module__hidden_dim': 53,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 6.440851380101242e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.0001384573943799594,
'nnet__max_epochs': 126,
'nnet__module__dropout': 0.8,
'nnet__module__hidden_dim': 23,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 0.007674007060713738},
{'nnet__batch_size': 32,
'nnet__lr': 1.3800975192197077e-05,
'nnet__max_epochs': 328,
'nnet__module__dropout': 0.2,
'nnet__module__hidden_dim': 56,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.00013544299070134376},
{'nnet__batch_size': 32,
'nnet__lr': 0.29305275974372563,
'nnet__max_epochs': 208,
'nnet__module__dropout': 0.4,
'nnet__module__hidden_dim': 26,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.0013175428816538889},
{'nnet__batch_size': 32,
'nnet__lr': 3.785227402204702e-05,
'nnet__max_epochs': 320,
'nnet__module__dropout': 0.4,
'nnet__module__hidden_dim': 59,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.038325988832442925},
{'nnet__batch_size': 32,
'nnet__lr': 3.504650876019769e-05,
'nnet__max_epochs': 236,
'nnet__module__dropout': 0.5,
'nnet__module__hidden_dim': 75,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.00015286332925360343},
{'nnet__batch_size': 32,
'nnet__lr': 0.057242022355422285,
'nnet__max_epochs': 318,
'nnet__module__dropout': 0.2,
'nnet__module__hidden_dim': 37,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 2.5912819605161014e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.0014692347652443534,
'nnet__max_epochs': 106,
'nnet__module__dropout': 0.1,
'nnet__module__hidden_dim': 42,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.022225239477743557},
{'nnet__batch_size': 32,
'nnet__lr': 0.00021828259540454984,
'nnet__max_epochs': 153,
'nnet__module__dropout': 0.9,
'nnet__module__hidden_dim': 54,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.015320184123405632},
{'nnet__batch_size': 32,
'nnet__lr': 0.00022520255999117748,
'nnet__max_epochs': 272,
'nnet__module__dropout': 0.6,
'nnet__module__hidden_dim': 29,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 3.75692244099265e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.2854601431192754,
'nnet__max_epochs': 248,
'nnet__module__dropout': 0.6,
'nnet__module__hidden_dim': 47,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 3.00997288321983e-06},
{'nnet__batch_size': 32,
'nnet__lr': 0.2670755499678923,
'nnet__max_epochs': 166,
'nnet__module__dropout': 0.3,
'nnet__module__hidden_dim': 21,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.00024279391489541403},
{'nnet__batch_size': 32,
'nnet__lr': 0.0003543378504703912,
'nnet__max_epochs': 159,
'nnet__module__dropout': 0.6,
'nnet__module__hidden_dim': 66,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.006594524291805665},
{'nnet__batch_size': 32,
'nnet__lr': 3.489084427199309e-05,
'nnet__max_epochs': 306,
'nnet__module__dropout': 0.3,
'nnet__module__hidden_dim': 49,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 8.805476667129781e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.42234022608372174,
'nnet__max_epochs': 110,
'nnet__module__dropout': 0.1,
'nnet__module__hidden_dim': 72,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 3.340722230503404e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.0009793285366373917,
'nnet__max_epochs': 185,
'nnet__module__dropout': 0.6,
'nnet__module__hidden_dim': 44,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 6.01969279860118e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.06497391993423347,
'nnet__max_epochs': 169,
'nnet__module__dropout': 0.0,
'nnet__module__hidden_dim': 52,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 9.871461672235258e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.00016005877496078506,
'nnet__max_epochs': 108,
'nnet__module__dropout': 0.8,
'nnet__module__hidden_dim': 34,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 1.5894157764750446e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.26547587190922456,
'nnet__max_epochs': 126,
'nnet__module__dropout': 0.9,
'nnet__module__hidden_dim': 62,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.018718129070616634},
{'nnet__batch_size': 32,
'nnet__lr': 0.0008398254984484967,
'nnet__max_epochs': 301,
'nnet__module__dropout': 0.1,
'nnet__module__hidden_dim': 75,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 7.168184021956474e-06},
{'nnet__batch_size': 32,
'nnet__lr': 0.13981802799450543,
'nnet__max_epochs': 53,
'nnet__module__dropout': 0.1,
'nnet__module__hidden_dim': 44,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 0.0008001291628848295},
{'nnet__batch_size': 32,
'nnet__lr': 0.00035155638470374623,
'nnet__max_epochs': 180,
'nnet__module__dropout': 0.2,
'nnet__module__hidden_dim': 66,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 1.0918910494880818e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.00029003276569871,
'nnet__max_epochs': 224,
'nnet__module__dropout': 0.6,
'nnet__module__hidden_dim': 75,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.0828236612161965},
{'nnet__batch_size': 32,
'nnet__lr': 0.00019395917168733251,
'nnet__max_epochs': 142,
'nnet__module__dropout': 0.8,
'nnet__module__hidden_dim': 31,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 0.0014584592294422608},
{'nnet__batch_size': 32,
'nnet__lr': 0.23205722838016982,
'nnet__max_epochs': 237,
'nnet__module__dropout': 0.7,
'nnet__module__hidden_dim': 45,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 5.585661948867728e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.1352152382314353,
'nnet__max_epochs': 149,
'nnet__module__dropout': 0.5,
'nnet__module__hidden_dim': 79,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 0.00017913582681945432},
{'nnet__batch_size': 32,
'nnet__lr': 0.005480470736619931,
'nnet__max_epochs': 150,
'nnet__module__dropout': 0.4,
'nnet__module__hidden_dim': 25,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 0.04344313929394729},
{'nnet__batch_size': 32,
'nnet__lr': 0.006998720616346746,
'nnet__max_epochs': 211,
'nnet__module__dropout': 0.5,
'nnet__module__hidden_dim': 29,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 1.749424338334307e-06},
{'nnet__batch_size': 32,
'nnet__lr': 0.034953202834835,
'nnet__max_epochs': 278,
'nnet__module__dropout': 0.3,
'nnet__module__hidden_dim': 68,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.00803507046980881},
{'nnet__batch_size': 32,
'nnet__lr': 0.00027082146536393204,
'nnet__max_epochs': 343,
'nnet__module__dropout': 0.1,
'nnet__module__hidden_dim': 43,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 3.6054614636344956e-06},
{'nnet__batch_size': 32,
'nnet__lr': 0.021103876854752467,
'nnet__max_epochs': 60,
'nnet__module__dropout': 0.8,
'nnet__module__hidden_dim': 73,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 8.044093739827801e-06},
{'nnet__batch_size': 32,
'nnet__lr': 1.738397164548146e-05,
'nnet__max_epochs': 57,
'nnet__module__dropout': 0.6,
'nnet__module__hidden_dim': 55,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.0008191349178468166},
{'nnet__batch_size': 32,
'nnet__lr': 0.11863655156430175,
'nnet__max_epochs': 187,
'nnet__module__dropout': 0.8,
'nnet__module__hidden_dim': 61,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 6.097841646576797e-06},
{'nnet__batch_size': 32,
'nnet__lr': 0.0004932731869274647,
'nnet__max_epochs': 139,
'nnet__module__dropout': 0.2,
'nnet__module__hidden_dim': 56,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 1.0699379977403527e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.012082825645612937,
'nnet__max_epochs': 158,
'nnet__module__dropout': 0.8,
'nnet__module__hidden_dim': 73,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 9.928726565697848e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.0013229222029814336,
'nnet__max_epochs': 167,
'nnet__module__dropout': 0.0,
'nnet__module__hidden_dim': 68,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 0.011434620128764766},
{'nnet__batch_size': 32,
'nnet__lr': 0.0021874486674363864,
'nnet__max_epochs': 177,
'nnet__module__dropout': 0.4,
'nnet__module__hidden_dim': 63,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.0006684309639782783},
{'nnet__batch_size': 32,
'nnet__lr': 0.0013169750116779452,
'nnet__max_epochs': 82,
'nnet__module__dropout': 0.7,
'nnet__module__hidden_dim': 21,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 0.0028424389283779566},
{'nnet__batch_size': 32,
'nnet__lr': 0.14594279408871638,
'nnet__max_epochs': 108,
'nnet__module__dropout': 0.6,
'nnet__module__hidden_dim': 21,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 9.065604548710317e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.03239827154479142,
'nnet__max_epochs': 103,
'nnet__module__dropout': 0.9,
'nnet__module__hidden_dim': 30,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.0018485208640276897},
{'nnet__batch_size': 32,
'nnet__lr': 0.9559422064058228,
'nnet__max_epochs': 121,
'nnet__module__dropout': 0.7,
'nnet__module__hidden_dim': 43,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 3.309942830684619e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.0029175187967904756,
'nnet__max_epochs': 69,
'nnet__module__dropout': 0.5,
'nnet__module__hidden_dim': 21,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 8.868208076731682e-06},
{'nnet__batch_size': 32,
'nnet__lr': 0.000335850145776681,
'nnet__max_epochs': 75,
'nnet__module__dropout': 0.1,
'nnet__module__hidden_dim': 73,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 1.8581252263907772e-06},
{'nnet__batch_size': 32,
'nnet__lr': 0.10789934402708082,
'nnet__max_epochs': 194,
'nnet__module__dropout': 0.7,
'nnet__module__hidden_dim': 36,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 0.0004765745875026169},
{'nnet__batch_size': 32,
'nnet__lr': 3.5138434536524006e-05,
'nnet__max_epochs': 142,
'nnet__module__dropout': 0.6,
'nnet__module__hidden_dim': 54,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.0026299408570035557},
{'nnet__batch_size': 32,
'nnet__lr': 0.004457856212421489,
'nnet__max_epochs': 178,
'nnet__module__dropout': 0.0,
'nnet__module__hidden_dim': 22,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 0.0002867828637544294},
{'nnet__batch_size': 32,
'nnet__lr': 0.12609857901253715,
'nnet__max_epochs': 64,
'nnet__module__dropout': 0.9,
'nnet__module__hidden_dim': 37,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 1.6020301413263788e-06},
{'nnet__batch_size': 32,
'nnet__lr': 0.00031811455812097975,
'nnet__max_epochs': 121,
'nnet__module__dropout': 0.1,
'nnet__module__hidden_dim': 34,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 0.0004983381237830977},
{'nnet__batch_size': 32,
'nnet__lr': 0.013537759936892031,
'nnet__max_epochs': 178,
'nnet__module__dropout': 0.8,
'nnet__module__hidden_dim': 33,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 7.51497668283994e-05},
{'nnet__batch_size': 32,
'nnet__lr': 3.056263488563348e-05,
'nnet__max_epochs': 308,
'nnet__module__dropout': 0.2,
'nnet__module__hidden_dim': 65,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 5.118622104539927e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.09876741017260972,
'nnet__max_epochs': 288,
'nnet__module__dropout': 0.0,
'nnet__module__hidden_dim': 38,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.001051237436877294},
{'nnet__batch_size': 32,
'nnet__lr': 0.0003354039958849039,
'nnet__max_epochs': 273,
'nnet__module__dropout': 0.9,
'nnet__module__hidden_dim': 47,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.0036514228635741385},
{'nnet__batch_size': 32,
'nnet__lr': 0.026186495085634885,
'nnet__max_epochs': 161,
'nnet__module__dropout': 0.5,
'nnet__module__hidden_dim': 50,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.00032385865032664383},
{'nnet__batch_size': 32,
'nnet__lr': 2.173918085804243e-05,
'nnet__max_epochs': 127,
'nnet__module__dropout': 0.8,
'nnet__module__hidden_dim': 44,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 1.1750587330969347e-05},
{'nnet__batch_size': 32,
'nnet__lr': 3.3668931226191865e-05,
'nnet__max_epochs': 185,
'nnet__module__dropout': 0.5,
'nnet__module__hidden_dim': 39,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.005286281863720477},
{'nnet__batch_size': 32,
'nnet__lr': 0.001013747100637488,
'nnet__max_epochs': 139,
'nnet__module__dropout': 0.3,
'nnet__module__hidden_dim': 25,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.00023603337607522662},
{'nnet__batch_size': 32,
'nnet__lr': 0.005555234816456331,
'nnet__max_epochs': 334,
'nnet__module__dropout': 0.1,
'nnet__module__hidden_dim': 41,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 9.285786642459022e-05},
{'nnet__batch_size': 32,
'nnet__lr': 0.15502639223980508,
'nnet__max_epochs': 70,
'nnet__module__dropout': 0.6,
'nnet__module__hidden_dim': 76,
'nnet__optimizer': torch.optim.adam.Adam,
'nnet__optimizer__weight_decay': 0.00024266164480618433},
{'nnet__batch_size': 32,
'nnet__lr': 0.24800916805387582,
'nnet__max_epochs': 331,
'nnet__module__dropout': 0.2,
'nnet__module__hidden_dim': 51,
'nnet__optimizer': __main__.AdaBound,
'nnet__optimizer__weight_decay': 0.0035547434004666316}],
'rank_test_accuracy_score': array([44, 4, 55, 2, 43, 38, 23, 53, 56, 21, 46, 10, 6, 25, 37, 32, 60,
18, 51, 7, 40, 5, 50, 62, 3, 13, 12, 26, 35, 61, 14, 41, 8, 47,
11, 52, 59, 48, 42, 39, 20, 15, 30, 22, 54, 64, 34, 17, 33, 49, 1,
63, 16, 45, 28, 24, 31, 19, 57, 36, 9, 27, 58, 29], dtype=int32),
'rank_test_average_precision_score': array([42, 6, 55, 1, 43, 38, 27, 53, 61, 20, 46, 2, 8, 23, 37, 28, 59,
18, 51, 7, 40, 5, 50, 63, 3, 11, 13, 25, 34, 58, 14, 44, 9, 52,
12, 54, 60, 47, 41, 39, 16, 17, 32, 22, 49, 64, 36, 19, 30, 48, 4,
62, 15, 45, 29, 31, 33, 21, 57, 35, 10, 24, 56, 26], dtype=int32),
'rank_test_f1_score': array([43, 5, 55, 1, 42, 38, 26, 53, 61, 21, 46, 4, 8, 23, 37, 28, 59,
18, 52, 7, 40, 6, 49, 64, 2, 11, 13, 25, 33, 58, 14, 44, 9, 50,
12, 54, 60, 47, 41, 39, 17, 16, 30, 22, 51, 63, 35, 19, 31, 48, 3,
62, 15, 45, 29, 34, 32, 20, 56, 36, 10, 24, 57, 27], dtype=int32),
'rank_test_precision_score': array([44, 7, 60, 4, 47, 40, 24, 58, 49, 23, 53, 5, 8, 26, 38, 31, 61,
22, 57, 9, 43, 2, 56, 1, 6, 15, 14, 27, 35, 42, 17, 46, 10, 52,
13, 59, 63, 54, 45, 41, 18, 16, 32, 25, 50, 39, 36, 20, 33, 55, 3,
64, 19, 51, 30, 11, 34, 21, 62, 37, 12, 29, 48, 28], dtype=int32),
'rank_test_recall_score': array([15, 59, 2, 58, 16, 25, 43, 3, 63, 34, 10, 51, 56, 32, 24, 26, 19,
44, 5, 54, 18, 62, 8, 64, 57, 45, 50, 35, 23, 17, 42, 39, 55, 27,
49, 7, 6, 11, 13, 22, 41, 47, 38, 36, 14, 53, 29, 46, 30, 9, 60,
1, 40, 12, 37, 61, 33, 48, 4, 21, 52, 28, 20, 31], dtype=int32),
'rank_test_roc_auc_score': array([45, 20, 54, 1, 39, 37, 36, 50, 61, 6, 42, 9, 26, 13, 35, 18, 59,
12, 47, 15, 38, 46, 48, 63, 2, 5, 14, 19, 24, 60, 3, 49, 23, 53,
8, 51, 58, 44, 33, 40, 11, 22, 32, 16, 56, 64, 34, 25, 31, 43, 17,
62, 4, 41, 27, 52, 29, 30, 55, 28, 7, 10, 57, 21], dtype=int32),
'split0_test_accuracy_score': array([0.78179059, 0.82261002, 0.68801214, 0.80576631, 0.77708649,
0.77359636, 0.76858877, 0.7138088 , 0.64218513, 0.7875569 ,
0.74233687, 0.81335357, 0.80455235, 0.79650986, 0.78285281,
0.79408194, 0.82200303, 0.77723824, 0.72003035, 0.81016692,
0.75326252, 0.66191199, 0.72078907, 0.86646434, 0.80242792,
0.80925645, 0.8060698 , 0.78088012, 0.76828528, 0.44172989,
0.79924127, 0.79787557, 0.79059181, 0.71942337, 0.81122914,
0.71699545, 0.62852807, 0.72078907, 0.75584219, 0.75978756,
0.75250379, 0.77192716, 0.77647951, 0.77526555, 0.75068285,
0.11259484, 0.77723824, 0.80379363, 0.79180577, 0.73444613,
0.82109256, 0.32852807, 0.77587253, 0.76858877, 0.7569044 ,
0.71487102, 0.77405159, 0.77116844, 0.66418816, 0.78270106,
0.79301973, 0.77207891, 0.88543247, 0.80030349]),
'split0_test_average_precision_score': array([0.2244501 , 0.24275992, 0.19069016, 0.23281254, 0.22223345,
0.21938465, 0.22022762, 0.19886476, 0.17895548, 0.22594646,
0.20861451, 0.24400887, 0.23176783, 0.23136608, 0.22523628,
0.22981709, 0.22355879, 0.22313437, 0.20066021, 0.23713303,
0.21369361, 0.18321448, 0.2010508 , 0.11499838, 0.22954523,
0.23674684, 0.23136792, 0.22378154, 0.21885975, 0.14444497,
0.23320353, 0.22414889, 0.22378949, 0.20000931, 0.23462328,
0.19979159, 0.17252199, 0.20070962, 0.21605784, 0.21291916,
0.21359614, 0.21940253, 0.2218012 , 0.22251891, 0.21026788,
0.11259484, 0.22076038, 0.23367038, 0.22718704, 0.20381082,
0.23813774, 0.13265676, 0.221371 , 0.21906507, 0.21485156,
0.19805223, 0.22009255, 0.22082949, 0.1826289 , 0.22431966,
0.22608305, 0.22106922, 0.22894193, 0.23282769]),
'split0_test_f1_score': array([0.39681208, 0.42611684, 0.33890675, 0.41122355, 0.39322594,
0.38902539, 0.38926712, 0.35322359, 0.31770833, 0.39965695,
0.37017804, 0.42630597, 0.40971586, 0.40794702, 0.39798065,
0.40560666, 0.40061318, 0.39438944, 0.35647018, 0.4173265 ,
0.37844037, 0.32566586, 0.35709294, 0.07172996, 0.40656335,
0.41670534, 0.40942699, 0.3958159 , 0.38748496, 0.25511237,
0.41069042, 0.39891697, 0.39737991, 0.35552457, 0.41431262,
0.3548945 , 0.30730051, 0.35664336, 0.38186708, 0.37848449,
0.37819291, 0.38877593, 0.39257732, 0.39328144, 0.37361799,
0.20240044, 0.39137645, 0.41200546, 0.40191805, 0.3627094 ,
0.42006886, 0.23456149, 0.39193084, 0.38779607, 0.38051044,
0.35229231, 0.3900041 , 0.39046079, 0.32509912, 0.39679865,
0.40070299, 0.39091646, 0.37962202, 0.41039427]),
'split0_test_precision_score': array([0.28806334, 0.33513514, 0.22255068, 0.31215084, 0.28350208,
0.27941176, 0.27692308, 0.23689052, 0.20228445, 0.29308176,
0.25537359, 0.32596291, 0.31041667, 0.30334865, 0.28929664,
0.30045425, 0.32263374, 0.28418549, 0.24047059, 0.31886121,
0.26414088, 0.20999219, 0.24103774, 0.16504854, 0.30716253,
0.31776362, 0.31153305, 0.28701456, 0.27584238, 0.15010722,
0.30671989, 0.29986431, 0.29392765, 0.23977433, 0.31837916,
0.23871568, 0.19448424, 0.2407932 , 0.26706072, 0.26703601,
0.26368953, 0.27839255, 0.28282828, 0.28251913, 0.26049973,
0.11259484, 0.28263473, 0.31091283, 0.29703608, 0.24850299,
0.33075136, 0.13455051, 0.28215768, 0.2761578 , 0.26681128,
0.23668365, 0.2801648 , 0.27886836, 0.21009066, 0.28860294,
0.29726206, 0.27958237, 0.48631579, 0.30738255]),
'split0_test_recall_score': array([0.63746631, 0.58490566, 0.71024259, 0.60242588, 0.64150943,
0.64016173, 0.65498652, 0.69407008, 0.73989218, 0.62803235,
0.67250674, 0.61590296, 0.60242588, 0.62264151, 0.63746631,
0.62398922, 0.52830189, 0.64420485, 0.68867925, 0.60377358,
0.6671159 , 0.72506739, 0.68867925, 0.0458221 , 0.60107817,
0.60512129, 0.59703504, 0.63746631, 0.6509434 , 0.8490566 ,
0.6212938 , 0.59568733, 0.61320755, 0.68733154, 0.59299191,
0.69137466, 0.73180593, 0.68733154, 0.66981132, 0.64959569,
0.66846361, 0.64420485, 0.64150943, 0.64690027, 0.66037736,
1. , 0.6361186 , 0.61051213, 0.6212938 , 0.67115903,
0.5754717 , 0.91374663, 0.64150943, 0.6509434 , 0.66307278,
0.68867925, 0.64150943, 0.6509434 , 0.71832884, 0.63477089,
0.61455526, 0.64959569, 0.31132075, 0.61725067]),
'split0_test_roc_auc_score': array([0.71878445, 0.71883792, 0.69771705, 0.71699611, 0.71789904,
0.7153442 , 0.71899463, 0.70519168, 0.68484007, 0.71791494,
0.71185186, 0.72715463, 0.71631212, 0.72060598, 0.71938295,
0.71982635, 0.69378501, 0.71916125, 0.70634373, 0.72006395,
0.71565439, 0.68948308, 0.70677122, 0.50820517, 0.71452677,
0.72013931, 0.71481369, 0.71827146, 0.71705857, 0.61955224,
0.72155661, 0.70960837, 0.71315302, 0.70541337, 0.7159556 ,
0.70581045, 0.673615 , 0.70618287, 0.71828459, 0.71168225,
0.71581525, 0.71616877, 0.71755704, 0.71922647, 0.71125913,
0.5 , 0.71563112, 0.71941475, 0.71736715, 0.70681754,
0.71386444, 0.5840108 , 0.71721505, 0.71722956, 0.71594131,
0.70343675, 0.71618905, 0.71868305, 0.68782379, 0.71812074,
0.71510937, 0.71860769, 0.63479854, 0.72039004]),
'split0_train_accuracy_score': array([0.77940061, 0.82507587, 0.68376328, 0.81445372, 0.77325493,
0.77242033, 0.76460546, 0.71043247, 0.63683612, 0.78596358,
0.74051593, 0.81232929, 0.80633536, 0.79548558, 0.78110774,
0.79468892, 0.82325493, 0.77537936, 0.71741275, 0.8111912 ,
0.75094841, 0.66278452, 0.71801973, 0.86433991, 0.81369499,
0.81028073, 0.81191199, 0.77913505, 0.76748862, 0.43937785,
0.79863429, 0.80269347, 0.78971927, 0.71756449, 0.81475721,
0.71452959, 0.62507587, 0.7189302 , 0.75269347, 0.75876328,
0.74988619, 0.76968892, 0.77348255, 0.77632777, 0.75113809,
0.11267071, 0.77564492, 0.80618361, 0.79150228, 0.73478756,
0.8262519 , 0.33156297, 0.77245827, 0.76760243, 0.75436267,
0.71221548, 0.77382398, 0.76862671, 0.65834598, 0.78118361,
0.79514416, 0.77226859, 0.88653263, 0.80113809]),
'split0_train_average_precision_score': array([0.22991833, 0.25515769, 0.19259704, 0.25528389, 0.22415662,
0.22771095, 0.22256215, 0.20238693, 0.18083527, 0.23268664,
0.21407621, 0.24789896, 0.24203149, 0.23713158, 0.22958301,
0.23720548, 0.23281399, 0.22855676, 0.20537573, 0.24770282,
0.21766217, 0.18646174, 0.20500637, 0.1140288 , 0.25788743,
0.24542265, 0.24993015, 0.23052592, 0.22463238, 0.14350522,
0.2391487 , 0.23561093, 0.23286298, 0.20588507, 0.25323858,
0.20345409, 0.17117686, 0.20539807, 0.2202675 , 0.22154571,
0.21746506, 0.22705197, 0.22629344, 0.23015106, 0.21712709,
0.11267071, 0.22725539, 0.24362693, 0.23605637, 0.21119451,
0.25503143, 0.13335569, 0.22843184, 0.22480877, 0.22029883,
0.20227476, 0.22653797, 0.22317552, 0.18205326, 0.23085329,
0.24053226, 0.22879134, 0.23924898, 0.23959669]),
'split0_train_f1_score': array([0.40316124, 0.44156473, 0.34092347, 0.43981216, 0.39497925,
0.3991988 , 0.39148769, 0.35732929, 0.31976124, 0.40772622,
0.37682216, 0.43079047, 0.42270723, 0.41484858, 0.40306228,
0.41479239, 0.41329807, 0.40076915, 0.36218854, 0.43035367,
0.38304671, 0.33019366, 0.36180991, 0.06485356, 0.44262853,
0.42747567, 0.43311228, 0.40385009, 0.39454707, 0.25356097,
0.41785479, 0.41423584, 0.40862051, 0.36285837, 0.4375072 ,
0.35930183, 0.30504184, 0.36244729, 0.38656253, 0.38920373,
0.3826201 , 0.39789745, 0.39765964, 0.40287624, 0.38241386,
0.20252301, 0.39922796, 0.42459736, 0.41282051, 0.3722726 ,
0.44159961, 0.23570747, 0.40008002, 0.39478364, 0.38689518,
0.35744537, 0.398021 , 0.39295312, 0.32367077, 0.40462428,
0.41885493, 0.40047938, 0.3968542 , 0.41884701]),
'split0_train_precision_score': array([0.2899749 , 0.34480802, 0.2227733 , 0.33327547, 0.28238529,
0.28410549, 0.27618652, 0.23823959, 0.20264793, 0.29621721,
0.25830627, 0.32721552, 0.318236 , 0.30610284, 0.29091995,
0.30551131, 0.33011467, 0.28649978, 0.24285222, 0.32599272,
0.26567592, 0.2126978 , 0.24282586, 0.14519906, 0.33384694,
0.32385082, 0.32790859, 0.29017069, 0.27918356, 0.14915617,
0.30985686, 0.31122017, 0.29907856, 0.24325875, 0.3325162 ,
0.24045584, 0.19278286, 0.24344006, 0.26825127, 0.27227523,
0.26501492, 0.28201884, 0.28388305, 0.28809386, 0.26542081,
0.11267071, 0.28585976, 0.31900491, 0.30234742, 0.25382637,
0.34613914, 0.13528182, 0.28457598, 0.27936242, 0.2691345 ,
0.23879584, 0.28425151, 0.27893175, 0.20829306, 0.2917535 ,
0.30781398, 0.28467982, 0.49472097, 0.3122314 ]),
'split0_train_recall_score': array([0.66127946, 0.61380471, 0.72592593, 0.64646465, 0.65690236,
0.67104377, 0.67205387, 0.71447811, 0.75757576, 0.65387205,
0.6962963 , 0.63030303, 0.62929293, 0.64343434, 0.65589226,
0.64579125, 0.55252525, 0.66666667, 0.71212121, 0.63299663,
0.68619529, 0.73771044, 0.70942761, 0.04175084, 0.65656566,
0.62861953, 0.63771044, 0.66397306, 0.67239057, 0.84511785,
0.64141414, 0.61919192, 0.64478114, 0.71380471, 0.63939394,
0.71043771, 0.73030303, 0.70909091, 0.69158249, 0.68215488,
0.68787879, 0.67542088, 0.66363636, 0.66969697, 0.68383838,
1. , 0.66161616, 0.63468013, 0.65050505, 0.6979798 ,
0.60976431, 0.91481481, 0.67340067, 0.67272727, 0.68787879,
0.71043771, 0.66363636, 0.66464646, 0.72558923, 0.65993266,
0.65521886, 0.67508418, 0.33131313, 0.63602694]),
'split0_train_roc_auc_score': array([0.72783939, 0.73285362, 0.70216775, 0.74112458, 0.72246571,
0.72816832, 0.72420564, 0.71219844, 0.68954034, 0.72830413,
0.72121356, 0.73287276, 0.72905433, 0.72911349, 0.72644976,
0.7296934 , 0.70507836, 0.72792504, 0.71510293, 0.73340725,
0.72268294, 0.69549053, 0.71426917, 0.50527046, 0.74510626,
0.73098356, 0.73587104, 0.72886554, 0.72597724, 0.61648795,
0.73000592, 0.72259297, 0.72645214, 0.71592331, 0.73820915,
0.71274344, 0.67100872, 0.71463524, 0.72601784, 0.72532285,
0.72281926, 0.72853985, 0.72553344, 0.72978222, 0.721761 ,
0.5 , 0.72587007, 0.7313204 , 0.72995539, 0.71872055,
0.73175261, 0.58615901, 0.72921851, 0.72618835, 0.7253417 ,
0.71143946, 0.72572583, 0.72323815, 0.68769842, 0.7282562 ,
0.73406518, 0.72984649, 0.64417303, 0.7290652 ]),
'split1_test_accuracy_score': array([0.78937785, 0.80394537, 0.67283763, 0.81259484, 0.75098634,
0.76889226, 0.82109256, 0.7138088 , 0.88421851, 0.79165402,
0.74370258, 0.83641882, 0.79590288, 0.78345979, 0.78512898,
0.78801214, 0.73474962, 0.79074355, 0.72670713, 0.78437026,
0.77025797, 0.81138088, 0.74962064, 0.11259484, 0.81213961,
0.80045524, 0.79438543, 0.79544765, 0.75887709, 0.12655539,
0.79286798, 0.75022762, 0.80485584, 0.72473445, 0.79984825,
0.71790592, 0.64597876, 0.73626707, 0.76403642, 0.82610015,
0.82913505, 0.81866464, 0.79059181, 0.79848255, 0.85705615,
0.11259484, 0.77799697, 0.79408194, 0.79742033, 0.74597876,
0.81183612, 0.11259484, 0.79301973, 0.75902883, 0.77996965,
0.88088012, 0.76494689, 0.81062215, 0.66646434, 0.78239757,
0.80789074, 0.79742033, 0.71426404, 0.76176024]),
'split1_test_average_precision_score': array([0.24267293, 0.24669357, 0.18914552, 0.25301693, 0.22351989,
0.23102809, 0.25339566, 0.2073575 , 0.21029418, 0.24414238,
0.21968058, 0.27268848, 0.24350136, 0.23914831, 0.23840605,
0.24030371, 0.20222192, 0.24212229, 0.21295417, 0.23945645,
0.23081818, 0.23346088, 0.22264383, 0.11259484, 0.25880969,
0.24961929, 0.24348376, 0.24353785, 0.22603691, 0.11346806,
0.24305029, 0.22039432, 0.25100217, 0.21326956, 0.2456362 ,
0.20848677, 0.17109789, 0.21593621, 0.22800325, 0.25542847,
0.26334087, 0.25724923, 0.2394836 , 0.24358908, 0.26505582,
0.11259484, 0.23451902, 0.23984966, 0.24480488, 0.22071983,
0.25272849, 0.11259484, 0.24571577, 0.2245975 , 0.23765426,
0.27856482, 0.22941977, 0.2484789 , 0.18442063, 0.23872411,
0.2533893 , 0.24565902, 0.21105471, 0.22955331]),
'split1_test_f1_score': array([0.42021721, 0.42781222, 0.33497841, 0.43684451, 0.39018952,
0.40251079, 0.43883865, 0.36412677, 0.34393809, 0.42238115,
0.38425082, 0.46368159, 0.42249893, 0.41492415, 0.41439206,
0.41718815, 0.36064375, 0.41985696, 0.37312913, 0.41546689,
0.40252565, 0.41284837, 0.38888889, 0.20240044, 0.44334532,
0.43048939, 0.42217484, 0.42245073, 0.39466667, 0.20365246,
0.42136499, 0.3862789 , 0.4329806 , 0.3731859 , 0.42577275,
0.36617797, 0.30668648, 0.37839771, 0.39798684, 0.44206426,
0.45180136, 0.44289044, 0.41673711, 0.42311034, 0.45612009,
0.20240044, 0.40841084, 0.41784642, 0.42432083, 0.38591343,
0.43636364, 0.20240044, 0.42447257, 0.39296636, 0.41247974,
0.467074 , 0.39984502, 0.43117593, 0.32782875, 0.41421569,
0.43633126, 0.4253121 , 0.36875629, 0.39938791]),
'split1_test_precision_score': array([0.30447942, 0.31860158, 0.2172 , 0.33011716, 0.26936891,
0.28389596, 0.33922001, 0.24280576, 0.47505938, 0.30703364,
0.26336832, 0.36750789, 0.3100189 , 0.29817325, 0.29892601,
0.3021148 , 0.24748996, 0.30519878, 0.25152511, 0.29899349,
0.28459821, 0.31781818, 0.26813075, 0.11259484, 0.33265857,
0.31716656, 0.30879601, 0.30967337, 0.27509294, 0.11347518,
0.30735931, 0.26701031, 0.32175623, 0.25092937, 0.31446945,
0.24509356, 0.19672131, 0.25754625, 0.27919609, 0.34603659,
0.35365854, 0.33856023, 0.30357143, 0.31217949, 0.3989899 ,
0.11259484, 0.29173888, 0.30648206, 0.31198478, 0.26512097,
0.32921811, 0.11259484, 0.30896806, 0.27427962, 0.29490151,
0.47058824, 0.28058728, 0.32575758, 0.21202532, 0.2971864 ,
0.32579787, 0.31246047, 0.24542615, 0.27884615]),
'split1_test_recall_score': array([0.67789757, 0.6509434 , 0.73180593, 0.64555256, 0.70754717,
0.69137466, 0.6212938 , 0.7277628 , 0.26954178, 0.67654987,
0.71024259, 0.62803235, 0.66307278, 0.6819407 , 0.67520216,
0.67385445, 0.66442049, 0.67250674, 0.72237197, 0.68059299,
0.68733154, 0.58894879, 0.70754717, 1. , 0.66442049,
0.66981132, 0.6671159 , 0.66442049, 0.69811321, 0.99191375,
0.66981132, 0.69811321, 0.66172507, 0.7277628 , 0.65902965,
0.72371968, 0.69541779, 0.71293801, 0.69272237, 0.61185984,
0.62533693, 0.64016173, 0.66442049, 0.65633423, 0.53234501,
1. , 0.68059299, 0.65633423, 0.66307278, 0.70889488,
0.64690027, 1. , 0.67789757, 0.69272237, 0.68598383,
0.46361186, 0.69541779, 0.63746631, 0.72237197, 0.68328841,
0.66037736, 0.66576819, 0.74123989, 0.70350404]),
'split1_test_roc_auc_score': array([0.74071007, 0.7371509 , 0.6985808 , 0.73967095, 0.73202256,
0.73505122, 0.73386851, 0.71990055, 0.61587554, 0.74140421,
0.7290953 , 0.74544572, 0.73791464, 0.73914067, 0.73713938,
0.73817551, 0.70404677, 0.73912615, 0.72481457, 0.73906531,
0.73405565, 0.71427604, 0.73125307, 0.5 , 0.74765142,
0.74342139, 0.73882471, 0.73824649, 0.73235004, 0.50433581,
0.73914643, 0.72747658, 0.74237074, 0.7260565 , 0.73837255,
0.72044397, 0.66756184, 0.72608255, 0.73290359, 0.73257151,
0.74016504, 0.7407375 , 0.73551052, 0.73642635, 0.71530041,
0.5 , 0.73547433, 0.73394687, 0.73876963, 0.72978944,
0.7398318 , 0.5 , 0.74276206, 0.73008212, 0.73893925,
0.69871769, 0.7345933 , 0.73502932, 0.69087135, 0.73913052,
0.74349237, 0.73994634, 0.7260406 , 0.73632794]),
'split1_train_accuracy_score': array([0.77833839, 0.80087253, 0.66972686, 0.80705615, 0.73835357,
0.7573217 , 0.81479514, 0.70709408, 0.88562215, 0.78524279,
0.73383915, 0.82765554, 0.79009863, 0.77575873, 0.77871775,
0.78247344, 0.72830046, 0.78224583, 0.71839909, 0.77553111,
0.76141882, 0.80481791, 0.74002276, 0.11267071, 0.80546282,
0.79411988, 0.78721548, 0.78846737, 0.74882398, 0.12594841,
0.78884674, 0.74419575, 0.79988619, 0.71585736, 0.79332322,
0.70963581, 0.64013657, 0.72871775, 0.75500759, 0.81828528,
0.82226859, 0.81437785, 0.78444613, 0.79355083, 0.84878604,
0.11267071, 0.77090288, 0.79059181, 0.79066768, 0.73679818,
0.80652504, 0.11267071, 0.78463581, 0.75128983, 0.77025797,
0.87530349, 0.7547041 , 0.80440061, 0.66176024, 0.77507587,
0.80231411, 0.78994689, 0.70151745, 0.75037936]),
'split1_train_average_precision_score': array([0.22682504, 0.24000695, 0.18588767, 0.24213957, 0.20973997,
0.21719199, 0.24384436, 0.19980499, 0.22725993, 0.23233139,
0.20904885, 0.25160174, 0.23368161, 0.22693983, 0.22760409,
0.23112173, 0.19512839, 0.22781275, 0.20340424, 0.22816876,
0.22177146, 0.22135897, 0.21133648, 0.11267071, 0.24319424,
0.23631633, 0.23387226, 0.23259431, 0.21559893, 0.11334795,
0.23444152, 0.21325338, 0.23873187, 0.20337053, 0.23618506,
0.19964062, 0.17020954, 0.20752652, 0.21797577, 0.24154694,
0.248621 , 0.24521327, 0.23049338, 0.2356401 , 0.24656092,
0.11267071, 0.22299594, 0.23459421, 0.2333083 , 0.21091185,
0.23971682, 0.11267071, 0.2316581 , 0.21656868, 0.22450491,
0.27025109, 0.21815899, 0.23687574, 0.18154075, 0.22684358,
0.23987033, 0.23480247, 0.19563977, 0.21507338]),
'split1_train_f1_score': array([0.39917738, 0.41929417, 0.33020465, 0.42296347, 0.37099863,
0.38354052, 0.42632197, 0.35351252, 0.37564713, 0.40716305,
0.369405 , 0.43767793, 0.4096874 , 0.39886098, 0.40020566,
0.40518672, 0.35009074, 0.40108514, 0.35981026, 0.40032431,
0.38995053, 0.39605588, 0.37329675, 0.20252301, 0.42394967,
0.41361426, 0.40939244, 0.40806794, 0.38011422, 0.20342968,
0.41038136, 0.37639878, 0.41757756, 0.35939104, 0.41331036,
0.35365648, 0.30485124, 0.36666371, 0.38413122, 0.42400192,
0.43328898, 0.42792003, 0.40477687, 0.41269156, 0.43235545,
0.20252301, 0.39312632, 0.41088581, 0.40933419, 0.37224032,
0.41992721, 0.20252301, 0.40623366, 0.38174274, 0.39488409,
0.45857355, 0.38430775, 0.41608154, 0.32331512, 0.39862055,
0.41938719, 0.4110201 , 0.34727062, 0.37971342]),
'split1_train_precision_score': array([0.28734271, 0.31224254, 0.2140008 , 0.31895962, 0.25440901,
0.26866478, 0.32743682, 0.23526134, 0.48789672, 0.29548564,
0.25196175, 0.34605598, 0.29985944, 0.2857351 , 0.2880829 ,
0.2928036 , 0.23962733, 0.29059571, 0.24185507, 0.28635639,
0.27387928, 0.30401874, 0.25624608, 0.11267071, 0.31810519,
0.30453461, 0.29783974, 0.2979845 , 0.26326028, 0.1133544 ,
0.29938176, 0.25946704, 0.31066207, 0.24088512, 0.30383154,
0.23602344, 0.19482952, 0.24876818, 0.26796168, 0.32977927,
0.33811591, 0.32778076, 0.29379562, 0.30368488, 0.37462981,
0.11267071, 0.28018908, 0.30078125, 0.30006277, 0.25451621,
0.31707317, 0.11267071, 0.29464421, 0.26512968, 0.28076158,
0.44888746, 0.26792353, 0.31348123, 0.20870076, 0.28523733,
0.3134055 , 0.30041984, 0.23040511, 0.26368159]),
'split1_train_recall_score': array([0.65353535, 0.63804714, 0.72255892, 0.62760943, 0.68484848,
0.67003367, 0.61077441, 0.71077441, 0.30538721, 0.65454545,
0.69191919, 0.5952862 , 0.64646465, 0.66026936, 0.65521886,
0.65757576, 0.64949495, 0.64713805, 0.7023569 , 0.66498316,
0.67676768, 0.56801347, 0.68720539, 1. , 0.63535354,
0.64444444, 0.65454545, 0.64713805, 0.68350168, 0.99057239,
0.65218855, 0.68518519, 0.63670034, 0.70740741, 0.64612795,
0.70505051, 0.7003367 , 0.6969697 , 0.67811448, 0.59360269,
0.6030303 , 0.61616162, 0.65050505, 0.64377104, 0.51111111,
1. , 0.65858586, 0.64814815, 0.64377104, 0.69259259,
0.62154882, 1. , 0.65387205, 0.68148148, 0.66531987,
0.46868687, 0.67946128, 0.61851852, 0.71717172, 0.66161616,
0.63367003, 0.65050505, 0.7047138 , 0.67811448]),
'split1_train_roc_auc_score': array([0.72386045, 0.7297974 , 0.69278865, 0.72872562, 0.71499799,
0.71921949, 0.72573778, 0.70870059, 0.63234303, 0.72819192,
0.71554061, 0.72622369, 0.72740077, 0.72534631, 0.72480909,
0.72795419, 0.69390096, 0.72326975, 0.71139649, 0.72727568,
0.72446763, 0.70145009, 0.71696738, 0.5 , 0.73120819,
0.72878486, 0.72930351, 0.72677552, 0.72031006, 0.50336657,
0.72919389, 0.71843697, 0.72865372, 0.71216886, 0.72907081,
0.70763427, 0.66641461, 0.71485937, 0.72144287, 0.72020879,
0.72656859, 0.72785422, 0.72597933, 0.72817026, 0.70138711,
0.5 , 0.72187523, 0.72841354, 0.72654563, 0.71750194,
0.72578082, 0.5 , 0.72755595, 0.7208177 , 0.7244513 ,
0.69781073, 0.72185975, 0.72326097, 0.68594798, 0.72554942,
0.72869906, 0.72907895, 0.7029127 , 0.71883492]),
'split2_test_accuracy_score': array([0.75037936, 0.80182094, 0.71578149, 0.784522 , 0.73915023,
0.75933232, 0.75432473, 0.71153263, 0.8676783 , 0.76327769,
0.71153263, 0.73277693, 0.75887709, 0.76145675, 0.76084977,
0.77207891, 0.18179059, 0.78437026, 0.70849772, 0.79317147,
0.77814871, 0.86707132, 0.7030349 , 0.11274659, 0.79954476,
0.77754173, 0.77496206, 0.77147193, 0.77435508, 0.87298938,
0.78543247, 0.75660091, 0.79362671, 0.66024279, 0.78588771,
0.71396055, 0.65417299, 0.71062215, 0.77116844, 0.76889226,
0.78437026, 0.79180577, 0.77389985, 0.77526555, 0.52124431,
0.88740516, 0.77283763, 0.7660091 , 0.76555387, 0.71608498,
0.79878604, 0.72261002, 0.78937785, 0.70698027, 0.77784522,
0.79241275, 0.77268589, 0.78983308, 0.69271624, 0.75508346,
0.80682853, 0.7629742 , 0.37814871, 0.75447648]),
'split2_test_average_precision_score': array([0.21008517, 0.23030225, 0.19086825, 0.22770069, 0.20429619,
0.21300676, 0.21100653, 0.19442345, 0.11627557, 0.21515805,
0.19508357, 0.20430217, 0.21459853, 0.2143608 , 0.21246458,
0.21950704, 0.10207781, 0.22515909, 0.19395716, 0.22867657,
0.22140866, 0.26906824, 0.19043525, 0.11261193, 0.23303574,
0.22216278, 0.22072951, 0.21637326, 0.22069676, 0.25128803,
0.22474389, 0.20501602, 0.22821162, 0.18109797, 0.22630194,
0.19196853, 0.17118335, 0.19234689, 0.22082949, 0.21502797,
0.22314485, 0.22595511, 0.21881326, 0.22015662, 0.15495028,
0.11259484, 0.21964068, 0.2181029 , 0.21588434, 0.19631169,
0.2294664 , 0.12944267, 0.22611996, 0.19193673, 0.22198375,
0.22437782, 0.21836594, 0.22239661, 0.18523277, 0.21184492,
0.23416495, 0.2157212 , 0.13471995, 0.21147054]),
'split2_test_f1_score': array([0.37333333, 0.40744102, 0.34257634, 0.40134907, 0.36403996,
0.37852665, 0.37514473, 0.34695981, 0.08016878, 0.38193344,
0.34785592, 0.36311031, 0.38050682, 0.38061466, 0.37805841,
0.38893409, 0.17553517, 0.39813638, 0.34593122, 0.40402274,
0.39235245, 0.45992602, 0.34041119, 0.20242805, 0.41053101,
0.39321192, 0.39096509, 0.38480392, 0.39082343, 0.43254237,
0.39778535, 0.36750789, 0.40350877, 0.3225416 , 0.39982986,
0.3438914 , 0.30750532, 0.34399725, 0.39046079, 0.38265099,
0.39557635, 0.40034965, 0.38834154, 0.39028407, 0.2742121 ,
0. , 0.38922889, 0.3861465 , 0.38323353, 0.35012157,
0.40591398, 0.22013652, 0.40017286, 0.3429738 , 0.39303483,
0.39841689, 0.38757155, 0.3954605 , 0.33190366, 0.3763524 ,
0.41309359, 0.3826087 , 0.23828996, 0.3757716 ]),
'split2_test_precision_score': array([0.26022305, 0.30711354, 0.23160892, 0.29202454, 0.2508924 ,
0.26685083, 0.26284478, 0.23282619, 0.18446602, 0.2704826 ,
0.23331799, 0.24814632, 0.26769062, 0.26893096, 0.26729911,
0.27855478, 0.09899966, 0.29030266, 0.23143508, 0.29902913,
0.28365385, 0.42386364, 0.22696629, 0.11261193, 0.30687125,
0.28375149, 0.28115771, 0.27608441, 0.28075338, 0.43519782,
0.29078456, 0.25975474, 0.29908973, 0.20795942, 0.2921069 ,
0.23181605, 0.19850922, 0.23094688, 0.27886836, 0.27362319,
0.28899938, 0.29624838, 0.27922078, 0.28097214, 0.16532594,
1. , 0.27911059, 0.2740113 , 0.27226319, 0.23584464,
0.30402685, 0.16104869, 0.29452926, 0.22940373, 0.28383234,
0.29569191, 0.27816901, 0.29244674, 0.21974661, 0.26381365,
0.31394534, 0.27069351, 0.13820612, 0.26324324]),
'split2_test_recall_score': array([0.66037736, 0.60512129, 0.65768194, 0.64150943, 0.66307278,
0.6509434 , 0.65498652, 0.68059299, 0.05121294, 0.64959569,
0.68328841, 0.67654987, 0.65768194, 0.6509434 , 0.64555256,
0.64420485, 0.77358491, 0.63342318, 0.68463612, 0.62264151,
0.6361186 , 0.50269542, 0.68059299, 1. , 0.61994609,
0.64016173, 0.64150943, 0.63477089, 0.64285714, 0.42991914,
0.62938005, 0.62803235, 0.61994609, 0.71832884, 0.63342318,
0.66576819, 0.6819407 , 0.67385445, 0.6509434 , 0.6361186 ,
0.62668464, 0.61725067, 0.63746631, 0.63881402, 0.8032345 ,
0. , 0.64285714, 0.65363881, 0.64690027, 0.67924528,
0.61051213, 0.34770889, 0.62398922, 0.67924528, 0.63881402,
0.61051213, 0.63881402, 0.61051213, 0.67789757, 0.65633423,
0.60377358, 0.65229111, 0.8638814 , 0.65633423]),
'split2_test_roc_auc_score': array([0.71108813, 0.71594984, 0.69041758, 0.72208851, 0.70593789,
0.71201411, 0.71095769, 0.69802563, 0.51124258, 0.71364873,
0.69920234, 0.70823047, 0.71469938, 0.7132111 , 0.71051568,
0.71625427, 0.44014403, 0.71847288, 0.69808071, 0.718725 ,
0.71614411, 0.70799956, 0.69323767, 0.5000855 , 0.72113926,
0.71756718, 0.71670205, 0.71179379, 0.71694841, 0.67956285,
0.71730631, 0.70047308, 0.71780478, 0.68560081, 0.71932787,
0.69292172, 0.66629525, 0.69457086, 0.71868305, 0.71092866,
0.7155311 , 0.71560208, 0.71433849, 0.71569634, 0.64434981,
0.5 , 0.71609341, 0.71695279, 0.71375451, 0.70000226,
0.71659327, 0.55894337, 0.71717587, 0.6948723 , 0.71714983,
0.7130023 , 0.71424285, 0.71154881, 0.68624701, 0.71197355,
0.71818296, 0.71465444, 0.59019993, 0.71163155]),
'split2_train_accuracy_score': array([0.75174507, 0.80299697, 0.71885432, 0.78630501, 0.74245068,
0.76282246, 0.75424886, 0.71525038, 0.86665402, 0.76384674,
0.71498483, 0.73353566, 0.75921851, 0.76164643, 0.76168437,
0.77518968, 0.1772003 , 0.78501517, 0.71107739, 0.7969651 ,
0.77807284, 0.86396055, 0.70754932, 0.11267071, 0.80235205,
0.77792109, 0.77591047, 0.77439302, 0.77325493, 0.87059939,
0.78679818, 0.76532625, 0.79723065, 0.66559181, 0.78729135,
0.717261 , 0.65443854, 0.71468134, 0.7727997 , 0.77200303,
0.78531866, 0.79681335, 0.77355842, 0.77598634, 0.52367223,
0.88732929, 0.77223065, 0.76984067, 0.76710926, 0.7185129 ,
0.80079666, 0.72579666, 0.79298179, 0.71031866, 0.77704856,
0.79578907, 0.77177542, 0.79036419, 0.69533384, 0.75652504,
0.81107739, 0.76392261, 0.37917299, 0.75508346]),
'split2_train_average_precision_score': array([0.21919413, 0.24261457, 0.20261453, 0.23706878, 0.21541256,
0.22522341, 0.22003676, 0.20519115, 0.120245 , 0.22514586,
0.20625225, 0.21451858, 0.22222839, 0.22375145, 0.22271751,
0.23001549, 0.10049597, 0.23399706, 0.20457495, 0.24069668,
0.23023774, 0.26816145, 0.20144855, 0.11267071, 0.24376471,
0.23244044, 0.23184853, 0.22853719, 0.22860867, 0.24577977,
0.23529355, 0.2214068 , 0.24219605, 0.1904786 , 0.23599372,
0.2033273 , 0.17711879, 0.20387252, 0.22917412, 0.22731358,
0.23382369, 0.24025024, 0.22694197, 0.2308012 , 0.15850861,
0.11267071, 0.22797118, 0.22893439, 0.22720441, 0.2068162 ,
0.24400513, 0.13935562, 0.23841569, 0.20308866, 0.23048272,
0.23486421, 0.22794228, 0.23244602, 0.19502568, 0.22131903,
0.24924689, 0.22539186, 0.13743521, 0.22171164]),
'split2_train_f1_score': array([0.38507799, 0.4228076 , 0.35885457, 0.41304574, 0.37880867,
0.39442077, 0.38655303, 0.36162613, 0.10764153, 0.39451415,
0.36292716, 0.37619893, 0.39012203, 0.39241853, 0.39116108,
0.4025005 , 0.17113158, 0.40913356, 0.36021505, 0.41939683,
0.40330477, 0.4592883 , 0.35570414, 0.20252301, 0.42405483,
0.40592653, 0.40483627, 0.40056446, 0.40044137, 0.42546741,
0.41102494, 0.39018139, 0.42122361, 0.33586981, 0.41195595,
0.35954284, 0.31629513, 0.35986041, 0.4010401 , 0.39863918,
0.40898172, 0.41883681, 0.39846821, 0.4035956 , 0.27963282,
0. , 0.3994799 , 0.40019773, 0.39760573, 0.36418166,
0.424043 , 0.24771024, 0.41592636, 0.35821146, 0.40341082,
0.41214372, 0.39936102, 0.40822446, 0.34563676, 0.38852896,
0.43215507, 0.39482641, 0.24274675, 0.38875213]),
'split2_train_precision_score': array([0.26707508, 0.31557989, 0.24147165, 0.29907952, 0.26008292,
0.27685613, 0.26890646, 0.2419208 , 0.21878225, 0.27739023,
0.24254789, 0.25548854, 0.27295953, 0.27526794, 0.27463255,
0.28727691, 0.0965211 , 0.29632986, 0.23998209, 0.309379 ,
0.28928885, 0.41589295, 0.23657588, 0.11267071, 0.3156682 ,
0.29052876, 0.28885694, 0.28585815, 0.2851836 , 0.42568251,
0.2983871 , 0.27585726, 0.31045491, 0.21634475, 0.29916222,
0.24137533, 0.20351589, 0.24080191, 0.28524683, 0.28359909,
0.29644209, 0.30899776, 0.2843377 , 0.28827009, 0.16853389,
1. , 0.28414912, 0.28327502, 0.28057056, 0.24425287,
0.31446234, 0.17927086, 0.30487996, 0.23868728, 0.28876617,
0.30499434, 0.28384899, 0.29930905, 0.22799097, 0.27092745,
0.32672414, 0.27758786, 0.14071134, 0.27041623]),
'split2_train_recall_score': array([0.68989899, 0.64040404, 0.6983165 , 0.66734007, 0.6969697 ,
0.68552189, 0.68720539, 0.71582492, 0.07138047, 0.68282828,
0.72053872, 0.71313131, 0.68350168, 0.68316498, 0.67946128,
0.67205387, 0.75387205, 0.66060606, 0.72188552, 0.65084175,
0.66565657, 0.51279461, 0.71649832, 1. , 0.64579125,
0.67340067, 0.67643098, 0.66902357, 0.67205387, 0.42525253,
0.66026936, 0.66632997, 0.65488215, 0.75050505, 0.66127946,
0.7043771 , 0.70942761, 0.71178451, 0.67508418, 0.67070707,
0.65925926, 0.64983165, 0.66565657, 0.67272727, 0.82053872,
0. , 0.67239057, 0.68148148, 0.68215488, 0.71548822,
0.65084175, 0.4006734 , 0.65420875, 0.71750842, 0.66902357,
0.63535354, 0.67340067, 0.64175084, 0.71414141, 0.68653199,
0.63804714, 0.68350168, 0.88316498, 0.69124579]),
'split2_train_roc_auc_score': array([0.72474855, 0.73202331, 0.70988933, 0.73437546, 0.72259772,
0.72907988, 0.72498363, 0.71550117, 0.51950811, 0.72848126,
0.71740916, 0.72462893, 0.72616726, 0.72738839, 0.72579306,
0.73016973, 0.42892406, 0.73070919, 0.71579526, 0.7331806 ,
0.72900186, 0.71067264, 0.71145566, 0.5 , 0.73401148,
0.73229675, 0.73248654, 0.72839806, 0.72907952, 0.67620044,
0.73156692, 0.72211325, 0.73509392, 0.7026574 , 0.73228573,
0.71163703, 0.67844189, 0.71341684, 0.73014576, 0.7277862 ,
0.73029231, 0.73265417, 0.72645804, 0.73091259, 0.65325782,
0.5 , 0.72864933, 0.73127088, 0.73002571, 0.71719259,
0.73533964, 0.58387667, 0.73240579, 0.71345707, 0.72989443,
0.72575714, 0.72883373, 0.72549278, 0.70354356, 0.72597228,
0.73554772, 0.72881796, 0.5991712 , 0.72721759]),
'split3_test_accuracy_score': array([0.79408194, 0.77025797, 0.68694992, 0.81790592, 0.74081942,
0.76783005, 0.75417299, 0.72200303, 0.21790592, 0.78224583,
0.73474962, 0.84810319, 0.83080425, 0.77814871, 0.7552352 ,
0.74446131, 0.62094082, 0.78482549, 0.72898331, 0.76646434,
0.73019727, 0.82670713, 0.72776935, 0.88725341, 0.79681335,
0.77769347, 0.79878604, 0.77966616, 0.78907436, 0.87481032,
0.77738998, 0.73960546, 0.81532625, 0.73095599, 0.77314112,
0.72063733, 0.64400607, 0.76889226, 0.71350531, 0.72716237,
0.78330804, 0.76646434, 0.75766313, 0.78270106, 0.57632777,
0.49726859, 0.7400607 , 0.76039454, 0.76874052, 0.72655539,
0.7861912 , 0.47253414, 0.79286798, 0.75781487, 0.76813354,
0.72139605, 0.7707132 , 0.75766313, 0.70318665, 0.76054628,
0.78694992, 0.77905918, 0.67496206, 0.75402124]),
'split3_test_average_precision_score': array([0.24238529, 0.23479443, 0.20107437, 0.26145826, 0.22234648,
0.23224244, 0.2275008 , 0.21527732, 0.12204636, 0.24396041,
0.2201318 , 0.27835572, 0.26244014, 0.24155316, 0.22897818,
0.22574792, 0.18116666, 0.24064367, 0.21926965, 0.23245296,
0.21816998, 0.26392544, 0.21640521, 0.11274659, 0.24898034,
0.23753171, 0.25244363, 0.24150757, 0.24535498, 0.24074392,
0.24055665, 0.21755077, 0.25660095, 0.20192944, 0.23732792,
0.21275323, 0.17624058, 0.2365905 , 0.21237256, 0.21641591,
0.24315471, 0.2344313 , 0.2271425 , 0.24184344, 0.17032274,
0.10116191, 0.22113933, 0.23133121, 0.23608026, 0.21678518,
0.24340477, 0.151547 , 0.24855227, 0.23072483, 0.23246001,
0.21001181, 0.23672913, 0.22829954, 0.20162852, 0.23182667,
0.24527661, 0.24267403, 0.19501153, 0.231245 ]),
'split3_test_f1_score': array([0.420828 , 0.40720439, 0.3522763 , 0.44751381, 0.38693467,
0.40374123, 0.39552239, 0.37517053, 0.2171932 , 0.42020202,
0.38320395, 0.47176781, 0.45100935, 0.41660016, 0.39745984,
0.3916185 , 0.31860338, 0.41694079, 0.38115038, 0.40371949,
0.38005579, 0.45201536, 0.37751561, 0. , 0.42899787,
0.41188278, 0.43333333, 0.41686747, 0.42323651, 0.41447835,
0.4153049 , 0.38095238, 0.44148692, 0.35969664, 0.41072132,
0.37188673, 0.31403509, 0.40900272, 0.37024683, 0.37742382,
0.4195122 , 0.40602084, 0.39576239, 0.41788618, 0.2998997 ,
0.0712083 , 0.3853606 , 0.4012135 , 0.40838509, 0.37776243,
0.42040313, 0.26697596, 0.42767296, 0.4 , 0.40405616,
0.36863824, 0.40953497, 0.39713099, 0.35530653, 0.40181956,
0.42269737, 0.41806555, 0.34294479, 0.39985191]),
'split3_test_precision_score': array([0.308125 , 0.28713418, 0.22972973, 0.34009797, 0.2638277 ,
0.28414701, 0.273619 , 0.25125628, 0.12241055, 0.30023095,
0.25968436, 0.38802083, 0.35559006, 0.29608622, 0.27507756,
0.26765432, 0.19979473, 0.30017762, 0.25664956, 0.28346028,
0.25647059, 0.35123043, 0.25432445, 1. , 0.31398252,
0.29347826, 0.31747026, 0.29708071, 0.30593881, 0.43843844,
0.29501699, 0.26022671, 0.33495822, 0.24580454, 0.29041249,
0.24908592, 0.20059768, 0.28735005, 0.24611973, 0.25407925,
0.30052417, 0.28463203, 0.27526316, 0.29935935, 0.18428351,
0.04497167, 0.26272016, 0.27930306, 0.28696127, 0.2540641 ,
0.30272512, 0.15828957, 0.31059683, 0.27751695, 0.28445909,
0.24757506, 0.28854626, 0.27597062, 0.23526844, 0.27968338,
0.30432208, 0.29732803, 0.22208979, 0.27579162]),
'split3_test_recall_score': array([0.66352624, 0.69986541, 0.75504711, 0.65410498, 0.72543742,
0.69717362, 0.71332436, 0.74024226, 0.96231494, 0.69986541,
0.730821 , 0.60161507, 0.61641992, 0.7025572 , 0.71601615,
0.7294751 , 0.78600269, 0.68236878, 0.74024226, 0.70121131,
0.73351279, 0.63391655, 0.73216689, 0. , 0.6769852 ,
0.69044415, 0.68236878, 0.69851952, 0.68640646, 0.39300135,
0.70121131, 0.71063257, 0.6473755 , 0.67025572, 0.70121131,
0.73351279, 0.72274563, 0.70928668, 0.74697174, 0.73351279,
0.69448183, 0.70794078, 0.7039031 , 0.69179004, 0.80484522,
0.17092867, 0.72274563, 0.71197847, 0.70794078, 0.73620458,
0.68775236, 0.85195155, 0.68640646, 0.71601615, 0.69717362,
0.72139973, 0.70524899, 0.70794078, 0.72543742, 0.71332436,
0.69179004, 0.7039031 , 0.75235532, 0.72678331]),
'split3_test_roc_auc_score': array([0.73709919, 0.73953421, 0.71667183, 0.74641285, 0.73410574,
0.73699112, 0.73634407, 0.72996378, 0.54281302, 0.74628981,
0.73303492, 0.74052021, 0.73723339, 0.7451558 , 0.73811753,
0.73792038, 0.69298424, 0.74010691, 0.73389743, 0.7379838 ,
0.73164437, 0.74256115, 0.72968871, 0.5 , 0.74451278,
0.73961236, 0.74797419, 0.74424864, 0.7442636 , 0.66451846,
0.7441408 , 0.72695986, 0.74202194, 0.70446256, 0.74174641,
0.72625699, 0.67837298, 0.74287662, 0.72811217, 0.72993409,
0.74453867, 0.74092096, 0.73419885, 0.74302175, 0.67606722,
0.35483324, 0.73250331, 0.73926271, 0.74220367, 0.7307669 ,
0.74322627, 0.63813586, 0.74640145, 0.73957127, 0.73716215,
0.72139766, 0.74214049, 0.73596115, 0.71289829, 0.73993565,
0.74541614, 0.74625632, 0.70874137, 0.74213289]),
'split3_train_accuracy_score': array([0.79074355, 0.76828528, 0.68179059, 0.81452959, 0.74028832,
0.7629742 , 0.75318665, 0.71475721, 0.22048558, 0.77943854,
0.72704856, 0.84506829, 0.82795903, 0.77435508, 0.75436267,
0.73497724, 0.61798179, 0.78186646, 0.72067527, 0.76157056,
0.72408953, 0.82048558, 0.72135812, 0.88736722, 0.79229894,
0.7754173 , 0.79559939, 0.77408953, 0.78550835, 0.87101669,
0.77325493, 0.73638088, 0.81062215, 0.73125948, 0.76710926,
0.71293627, 0.6314871 , 0.7624431 , 0.70694234, 0.72496206,
0.77955235, 0.76361912, 0.75588012, 0.77716237, 0.57393778,
0.49207132, 0.73960546, 0.75599393, 0.765478 , 0.72010622,
0.78277693, 0.47606222, 0.79116085, 0.75421093, 0.76604704,
0.71798179, 0.76911988, 0.75128983, 0.69578907, 0.759522 ,
0.78327011, 0.77591047, 0.67185129, 0.74965857]),
'split3_train_average_precision_score': array([0.22769623, 0.22449964, 0.19024971, 0.25022901, 0.21194345,
0.21762952, 0.21719646, 0.20068837, 0.12234595, 0.22733129,
0.20460704, 0.25773339, 0.24688556, 0.22384939, 0.21634438,
0.20915634, 0.17486381, 0.22732425, 0.20313222, 0.2203372 ,
0.20484713, 0.24100725, 0.20288813, 0.11263278, 0.23921283,
0.22599532, 0.23491333, 0.22543911, 0.22987942, 0.22479653,
0.22425288, 0.20736496, 0.23950871, 0.19545012, 0.2224231 ,
0.19935372, 0.16583858, 0.21966969, 0.19914826, 0.20574828,
0.22661326, 0.22045151, 0.2164595 , 0.22586504, 0.16586961,
0.1004691 , 0.2105513 , 0.21803913, 0.2221848 , 0.20283565,
0.23063996, 0.14938861, 0.23349266, 0.21662371, 0.22140934,
0.1991939 , 0.22185698, 0.21452007, 0.18947932, 0.21889647,
0.23223711, 0.22734679, 0.18827572, 0.21490651]),
'split3_train_f1_score': array([0.40238353, 0.39452815, 0.33754541, 0.43394697, 0.37410861,
0.38503937, 0.38284955, 0.35575358, 0.21769588, 0.4 ,
0.36265391, 0.44691224, 0.43191783, 0.39479039, 0.38197957,
0.36972212, 0.30970661, 0.4004171 , 0.3597948 , 0.38820208,
0.36252082, 0.42362972, 0.35957799, 0. , 0.41674656,
0.3976394 , 0.41217543, 0.39671766, 0.40421496, 0.39111748,
0.39510171, 0.36764037, 0.42034371, 0.35092542, 0.39175666,
0.35374498, 0.29751229, 0.38751956, 0.35263555, 0.36381186,
0.39913142, 0.38869813, 0.38237835, 0.39778552, 0.29342561,
0.0785906 , 0.3722334 , 0.38437979, 0.39117589, 0.35932615,
0.40465793, 0.26384521, 0.40965147, 0.38230527, 0.39031142,
0.35423905, 0.3914 , 0.37916667, 0.33820253, 0.38605327,
0.40668813, 0.39938993, 0.33369281, 0.37938493]),
'split3_train_precision_score': array([0.29659799, 0.27953364, 0.22046838, 0.33062809, 0.25674489,
0.27200668, 0.266473 , 0.23856585, 0.12271966, 0.28834995,
0.24603365, 0.37372593, 0.34383726, 0.28284006, 0.26651572,
0.25249538, 0.19442293, 0.28998641, 0.24249883, 0.2730011 ,
0.2450237 , 0.33180691, 0.24258824, 1. , 0.30476784,
0.28488118, 0.30482492, 0.28368589, 0.29412667, 0.41759082,
0.28240741, 0.25187032, 0.32075137, 0.2410321 , 0.27751263,
0.23695652, 0.18942812, 0.27305307, 0.23471664, 0.24599502,
0.28797374, 0.27422481, 0.26738255, 0.28592483, 0.18041157,
0.04938592, 0.25549278, 0.26848509, 0.27640919, 0.24207324,
0.29267559, 0.15672492, 0.30050346, 0.26662234, 0.27623846,
0.23867494, 0.27833879, 0.26373337, 0.22398338, 0.27093529,
0.29399399, 0.28604719, 0.21631879, 0.2631785 ]),
'split3_train_recall_score': array([0.62546312, 0.67025935, 0.71977097, 0.63118895, 0.68912092,
0.65880768, 0.67969013, 0.69922533, 0.96295049, 0.65274503,
0.68945773, 0.55574267, 0.58066689, 0.65341866, 0.6739643 ,
0.69013136, 0.76086224, 0.64668238, 0.69686763, 0.6716066 ,
0.69653082, 0.5857191 , 0.69450994, 0. , 0.65880768,
0.65813405, 0.63624116, 0.65948131, 0.64600876, 0.36780061,
0.65746042, 0.68036376, 0.60963287, 0.64499832, 0.66588077,
0.69754126, 0.69282587, 0.66722802, 0.70865611, 0.69821489,
0.65005052, 0.66722802, 0.67093297, 0.65341866, 0.78544965,
0.19232065, 0.68541596, 0.67632199, 0.66891209, 0.69686763,
0.65543954, 0.83361401, 0.64331425, 0.67531155, 0.66487033,
0.68676322, 0.65914449, 0.67430111, 0.69013136, 0.67126979,
0.65948131, 0.66150219, 0.72953857, 0.67935332]),
'split3_train_roc_auc_score': array([0.71859279, 0.72549349, 0.69837037, 0.73449491, 0.71795193,
0.71750183, 0.72110281, 0.70797699, 0.54459781, 0.72413234,
0.71063883, 0.71876741, 0.72000725, 0.72156205, 0.71926593,
0.71540042, 0.68035417, 0.72285383, 0.71028239, 0.72229811,
0.71205918, 0.7180017 , 0.70963794, 0.5 , 0.73402528,
0.72421901, 0.72603388, 0.72405898, 0.72461183, 0.65134505,
0.72270653, 0.71192742, 0.72288321, 0.69360343, 0.72291944,
0.7062158 , 0.65826365, 0.72087834, 0.70769046, 0.71328597,
0.72302022, 0.72154099, 0.71879768, 0.72314386, 0.6662702 ,
0.36121953, 0.71594983, 0.72121431, 0.72332356, 0.70996175,
0.72718965, 0.63214624, 0.72662057, 0.71976855, 0.72187982,
0.70435378, 0.72111173, 0.71768153, 0.69331928, 0.72099679,
0.72923191, 0.7259672 , 0.69703383, 0.71896784]),
'split4_test_accuracy_score': array([0.57951442, 0.80364188, 0.69590288, 0.80212443, 0.74719272,
0.75781487, 0.78588771, 0.70075873, 0.82898331, 0.7646434 ,
0.73459788, 0.73004552, 0.79180577, 0.75887709, 0.74430956,
0.75948407, 0.75887709, 0.77996965, 0.69939302, 0.82685888,
0.74734446, 0.82352049, 0.69772382, 0.88725341, 0.80789074,
0.78270106, 0.77890744, 0.74931715, 0.75098634, 0.67223065,
0.77283763, 0.73080425, 0.77450683, 0.82443096, 0.78497724,
0.69650986, 0.6400607 , 0.72078907, 0.75477997, 0.70455235,
0.75781487, 0.77875569, 0.76676783, 0.75432473, 0.792261 ,
0.11274659, 0.77572079, 0.79620637, 0.72716237, 0.72094082,
0.80986343, 0.58968134, 0.77010622, 0.6952959 , 0.78846737,
0.77405159, 0.77860395, 0.78634294, 0.67996965, 0.74871017,
0.7693475 , 0.76009105, 0.69590288, 0.79863429]),
'split4_test_average_precision_score': array([0.16551411, 0.22890832, 0.19424306, 0.22932159, 0.20820201,
0.21169336, 0.21987621, 0.19650627, 0.21875871, 0.21568242,
0.20460773, 0.20387529, 0.22434261, 0.21161117, 0.20580491,
0.21462018, 0.2020131 , 0.21877354, 0.19553866, 0.23943345,
0.20792839, 0.23967406, 0.19443854, 0.11274659, 0.22871161,
0.22233516, 0.21684223, 0.21055813, 0.20972852, 0.18828272,
0.21926594, 0.19702283, 0.21458247, 0.22150186, 0.22281248,
0.19452253, 0.17073538, 0.19969671, 0.21352836, 0.19765998,
0.21281351, 0.21752085, 0.21555306, 0.21213104, 0.22551444,
0.11274659, 0.21814121, 0.22368116, 0.19992125, 0.20181835,
0.23214514, 0.16852967, 0.21662774, 0.19364376, 0.22018261,
0.21195978, 0.21741418, 0.22383253, 0.18662443, 0.20837039,
0.2141988 , 0.21200029, 0.18977581, 0.22728271]),
'split4_test_f1_score': array([0.29329253, 0.40587695, 0.34466972, 0.40619308, 0.37037037,
0.3765625 , 0.39154808, 0.34831461, 0.39376009, 0.38280939,
0.36376864, 0.36213697, 0.39824561, 0.37661828, 0.36677941,
0.38061743, 0.36363636, 0.38921651, 0.3468513 , 0.42227848,
0.37003405, 0.42225534, 0.34516765, 0. , 0.40619137,
0.39424704, 0.38652632, 0.37376801, 0.37294612, 0.3337446 ,
0.38873009, 0.35302699, 0.38289037, 0.39770953, 0.3952198 ,
0.34512115, 0.30562061, 0.35529082, 0.37846154, 0.35035035,
0.37802027, 0.38739496, 0.38297872, 0.37658837, 0.39982464,
0.20264557, 0.38773819, 0.39802779, 0.35647817, 0.35811518,
0.41090738, 0.29839128, 0.38489647, 0.34379085, 0.39232781,
0.37932472, 0.38723226, 0.39674379, 0.33238367, 0.37082067,
0.38161107, 0.3773139 , 0.33861386, 0.40305893]),
'split4_test_precision_score': array([0.1809314 , 0.30801394, 0.22764579, 0.30695114, 0.25748818,
0.26527243, 0.28807107, 0.23083662, 0.32795699, 0.27175141,
0.24925224, 0.24682307, 0.29538061, 0.26578073, 0.2544317 ,
0.26817181, 0.25883694, 0.2832618 , 0.22969432, 0.33847403,
0.25736842, 0.33464567, 0.22836016, 1. , 0.31173506,
0.28747687, 0.28125 , 0.26015831, 0.26040555, 0.21648659,
0.27901524, 0.24212106, 0.27687688, 0.32427844, 0.289375 ,
0.22803981, 0.1952862 , 0.24017054, 0.26494346, 0.23291925,
0.26604498, 0.28161271, 0.2728833 , 0.26375405, 0.29648895,
0.11274659, 0.28007181, 0.2983871 , 0.24280839, 0.24175306,
0.31575145, 0.18482803, 0.2755814 , 0.2270177 , 0.2901354 ,
0.27475845, 0.28144078, 0.29101194, 0.21730132, 0.25833774,
0.27346939, 0.26670379, 0.22431132, 0.3027027 ]),
'split4_test_recall_score': array([0.77388964, 0.5948856 , 0.70928668, 0.60026918, 0.65948856,
0.6487214 , 0.61103634, 0.70928668, 0.49259758, 0.6473755 ,
0.67294751, 0.67967699, 0.61103634, 0.64602961, 0.65679677,
0.65545087, 0.61103634, 0.6218035 , 0.70794078, 0.56123822,
0.65814266, 0.57200538, 0.70659489, 0. , 0.58277254,
0.62718708, 0.61776581, 0.66352624, 0.65679677, 0.72812921,
0.64064603, 0.65141319, 0.6204576 , 0.5141319 , 0.62314939,
0.70928668, 0.7025572 , 0.68236878, 0.66218035, 0.70659489,
0.65275908, 0.6204576 , 0.64199192, 0.65814266, 0.61372813,
1. , 0.62987887, 0.59757739, 0.67025572, 0.69044415,
0.58815612, 0.77388964, 0.63795424, 0.70794078, 0.60565276,
0.61238223, 0.6204576 , 0.62314939, 0.70659489, 0.65679677,
0.63122476, 0.64468371, 0.69044415, 0.60296097]),
'split4_test_roc_auc_score': array([0.66435204, 0.71252746, 0.70174442, 0.71402205, 0.70891308,
0.71019959, 0.70957153, 0.70448086, 0.68216333, 0.71346029,
0.70768976, 0.70806151, 0.71290657, 0.70962332, 0.70611345,
0.71407741, 0.69435005, 0.71093596, 0.7031238 , 0.71092525,
0.70841116, 0.71374341, 0.70159572, 0.5 , 0.70963495,
0.71482494, 0.70857506, 0.71187258, 0.70987607, 0.69662831,
0.71514087, 0.69615298, 0.70727002, 0.68899685, 0.71434535,
0.70208647, 0.66733812, 0.70402003, 0.71436365, 0.70544384,
0.71196189, 0.70966441, 0.71230775, 0.71234481, 0.71433798,
0.5 , 0.71206617, 0.70951214, 0.70232471, 0.70763015,
0.71309636, 0.67008147, 0.71242675, 0.70081493, 0.70867553,
0.70348888, 0.7095789 , 0.71511497, 0.69159058, 0.70859336,
0.70906201, 0.70972 , 0.69352035, 0.7132301 ]),
'split4_train_accuracy_score': array([0.59248862, 0.81801973, 0.70982549, 0.81194992, 0.76289833,
0.76862671, 0.79723065, 0.71513657, 0.84044006, 0.77697269,
0.75064492, 0.74191958, 0.80174507, 0.77219272, 0.75728376,
0.77067527, 0.77105463, 0.79427162, 0.71293627, 0.83721548,
0.76054628, 0.8341047 , 0.71259484, 0.88736722, 0.81911988,
0.79472686, 0.7947648 , 0.76126707, 0.76289833, 0.68042489,
0.78418058, 0.75079666, 0.78941578, 0.83156297, 0.79889985,
0.71066009, 0.65531108, 0.73566009, 0.76847496, 0.71794385,
0.77424127, 0.79207132, 0.7806525 , 0.7689302 , 0.80136571,
0.11263278, 0.78911229, 0.80929439, 0.74248862, 0.73603945,
0.82264795, 0.60117602, 0.78729135, 0.70978756, 0.80060698,
0.78327011, 0.79025038, 0.79810319, 0.69586495, 0.76206373,
0.78573596, 0.77621396, 0.71532625, 0.80861153]),
'split4_train_average_precision_score': array([0.17124719, 0.2537846 , 0.20065434, 0.25173274, 0.22527522,
0.22630449, 0.23827336, 0.20368107, 0.2385443 , 0.22982369,
0.21775216, 0.21500304, 0.24033326, 0.22814138, 0.22124299,
0.22913457, 0.21493836, 0.23801166, 0.20323366, 0.26110089,
0.22272607, 0.26090547, 0.20204427, 0.11263278, 0.25433034,
0.24102404, 0.24105594, 0.22436421, 0.22305046, 0.19270251,
0.23416567, 0.21812776, 0.23510001, 0.23300987, 0.24438021,
0.20165938, 0.1766527 , 0.21025951, 0.22688526, 0.20617135,
0.23081975, 0.23662101, 0.23095671, 0.22671377, 0.23818994,
0.11263278, 0.23734553, 0.24691009, 0.21616987, 0.21218341,
0.25466227, 0.17262201, 0.2384759 , 0.20139042, 0.24084073,
0.22545029, 0.23400694, 0.24113421, 0.19349929, 0.22345336,
0.235999 , 0.23147051, 0.2036083 , 0.24148226]),
'split4_train_f1_score': array([0.30246753, 0.43875044, 0.35500464, 0.43522844, 0.39449719,
0.39679557, 0.41654841, 0.35968278, 0.42177619, 0.40260136,
0.38310652, 0.3782104 , 0.41984902, 0.3996801 , 0.38857034,
0.40059494, 0.38286123, 0.41568796, 0.35878315, 0.45036506,
0.39096874, 0.44972946, 0.35720346, 0. , 0.43958627,
0.4193583 , 0.41940331, 0.3930948 , 0.39178669, 0.3406387 ,
0.4091806 , 0.38359764, 0.41128434, 0.41409343, 0.42411733,
0.35642562, 0.3157102 , 0.37123263, 0.3974726 , 0.36327824,
0.40328888, 0.41360864, 0.40465404, 0.39734837, 0.41718611,
0.20246173, 0.41391671, 0.42907439, 0.37975146, 0.37371737,
0.44058873, 0.30510939, 0.41490139, 0.35595218, 0.42025149,
0.3983149 , 0.41011416, 0.42013511, 0.34367581, 0.39213026,
0.41166667, 0.4044422 , 0.35961768, 0.42243847]),
'split4_train_precision_score': array([0.1873542 , 0.33614199, 0.2367829 , 0.32885675, 0.27689378,
0.2808737 , 0.30813953, 0.2408084 , 0.35632985, 0.28827125,
0.26554775, 0.25953337, 0.31313131, 0.28419107, 0.2712475 ,
0.28386734, 0.27488987, 0.30560837, 0.23969656, 0.36337329,
0.2739689 , 0.35897951, 0.23874334, 1. , 0.33760607,
0.30771654, 0.307765 , 0.27540541, 0.27548926, 0.22188233,
0.2957958 , 0.26586889, 0.3001548 , 0.34042092, 0.31302117,
0.23778428, 0.20331749, 0.2535437 , 0.28114525, 0.24356913,
0.28712164, 0.30307306, 0.29141332, 0.28131129, 0.31155445,
0.11263278, 0.30125844, 0.3236806 , 0.26056426, 0.25500553,
0.34168523, 0.18980263, 0.30057454, 0.23728814, 0.31244875,
0.28976402, 0.30012492, 0.31051699, 0.22701709, 0.27527555,
0.29799427, 0.28878316, 0.24082752, 0.31997919]),
'split4_train_recall_score': array([0.78443921, 0.63152577, 0.70899293, 0.64331425, 0.68575278,
0.67564837, 0.64264062, 0.71034018, 0.51667228, 0.66722802,
0.68743685, 0.69686763, 0.63691479, 0.67329067, 0.68474234,
0.68036376, 0.63051533, 0.64971371, 0.71303469, 0.59211856,
0.68238464, 0.60188616, 0.70899293, 0. , 0.6298417 ,
0.65813405, 0.65813405, 0.68642641, 0.67800606, 0.7329067 ,
0.66352307, 0.68844729, 0.65308185, 0.52846076, 0.65746042,
0.71135062, 0.7059616 , 0.69282587, 0.67800606, 0.71438195,
0.67733244, 0.65106096, 0.661839 , 0.67632199, 0.63118895,
1. , 0.66116538, 0.63624116, 0.69989896, 0.69922533,
0.6200741 , 0.77736612, 0.66958572, 0.71202425, 0.64163018,
0.63691479, 0.64735601, 0.64937689, 0.70697204, 0.6813742 ,
0.66554395, 0.67463793, 0.70966655, 0.62142135]),
'split4_train_roc_auc_score': array([0.67628185, 0.73660851, 0.70946205, 0.73833448, 0.72922156,
0.72803837, 0.72974663, 0.71304278, 0.69910396, 0.72906525,
0.72305236, 0.72225281, 0.72979081, 0.72901847, 0.72561686,
0.73125109, 0.70970424, 0.73116697, 0.71297923, 0.73022199,
0.72642596, 0.73273308, 0.71102248, 0.5 , 0.73649325,
0.73509926, 0.73512064, 0.72859647, 0.72583985, 0.70333506,
0.7315093 , 0.72357895, 0.72990119, 0.69924812, 0.73715653,
0.71096153, 0.67742183, 0.71696144, 0.72898208, 0.71638896,
0.73193713, 0.7305153 , 0.7287862 , 0.72850344, 0.72707753,
0.5 , 0.73325893, 0.73375052, 0.72389672, 0.71996878,
0.73421729, 0.67808924, 0.73590867, 0.71076395, 0.73120798,
0.71938083, 0.72787193, 0.73317889, 0.70071359, 0.72683989,
0.73326789, 0.73187242, 0.71285559, 0.72689639]),
'std_fit_time': array([ 9.55591978, 38.78627526, 42.77974378, 28.80870091, 7.51955155,
6.47681714, 16.12331419, 52.35156764, 12.11289074, 11.14451946,
3.46580088, 6.84573783, 1.9187663 , 11.83772691, 19.10549283,
32.07578725, 9.23158141, 21.09325156, 66.50479657, 13.6652202 ,
24.27784246, 6.74279631, 49.55246899, 5.24876387, 36.72516275,
10.58016943, 55.63630324, 4.05581845, 5.85394554, 9.72607113,
38.55740524, 5.47683185, 19.82841326, 5.99117874, 70.94387215,
17.6914669 , 0.44523821, 44.58256972, 13.97144978, 16.82337673,
11.73868969, 14.26484731, 9.93361587, 29.0932304 , 5.39685493,
4.7939115 , 14.20067175, 1.23638558, 16.37670038, 2.1355177 ,
8.98804719, 8.10174758, 14.53091054, 72.18667882, 32.67155009,
8.02887867, 16.33845509, 8.69983329, 0.57649153, 1.60839863,
30.59197664, 35.21019253, 5.30584551, 7.77813193]),
'std_score_time': array([0.01592433, 0.01561245, 0.00952453, 0.0173178 , 0.01410845,
0.00823105, 0.01569103, 0.02171895, 0.01868987, 0.01366252,
0.02016532, 0.01832044, 0.01237464, 0.00568729, 0.01358747,
0.01497355, 0.00978536, 0.00869258, 0.00696509, 0.01557709,
0.01288276, 0.0153484 , 0.00570296, 0.0119586 , 0.01294698,
0.02367343, 0.00544206, 0.02036147, 0.00650833, 0.01002184,
0.01084726, 0.0268889 , 0.00882857, 0.00613621, 0.01555662,
0.01851585, 0.01240622, 0.0093098 , 0.02238715, 0.01050826,
0.01098161, 0.0064774 , 0.00706173, 0.01321246, 0.00977176,
0.02167686, 0.00723703, 0.00499105, 0.00971381, 0.0161403 ,
0.01413122, 0.01074589, 0.00946423, 0.01346098, 0.01145903,
0.01965316, 0.01126411, 0.01466676, 0.01069918, 0.00841139,
0.01112975, 0.03287817, 0.02176054, 0.08338507]),
'std_test_accuracy_score': array([0.08120113, 0.01689325, 0.01406656, 0.01141559, 0.01370439,
0.00599314, 0.02501584, 0.00681847, 0.25049932, 0.01175467,
0.01155377, 0.0508374 , 0.02314924, 0.01403667, 0.01588892,
0.01821168, 0.23031304, 0.00461804, 0.01121535, 0.02082482,
0.01698189, 0.07063309, 0.01855901, 0.37614824, 0.00556919,
0.01295338, 0.01184833, 0.01513853, 0.01309038, 0.28440573,
0.00968718, 0.02317791, 0.01377337, 0.05273447, 0.01318224,
0.00861399, 0.00838787, 0.02042662, 0.02008256, 0.04138749,
0.02713796, 0.01861031, 0.01092168, 0.01424058, 0.12883689,
0.30963007, 0.01446327, 0.01745435, 0.02483421, 0.01054007,
0.01200172, 0.21110038, 0.00946257, 0.03018386, 0.01082639,
0.05998125, 0.00446079, 0.01790371, 0.01499358, 0.01411095,
0.01420438, 0.01335888, 0.16385535, 0.02111266]),
'std_test_average_precision_score': array([0.02848728, 0.00695509, 0.00427244, 0.01373422, 0.00816577,
0.00870586, 0.01447477, 0.00776176, 0.04304511, 0.01289484,
0.00947792, 0.03204064, 0.01652233, 0.01243508, 0.01167368,
0.00884182, 0.04224836, 0.00955862, 0.00995805, 0.00423348,
0.00768272, 0.03053567, 0.01249525, 0.00093158, 0.01199594,
0.01038898, 0.01343326, 0.01324746, 0.01184146, 0.05336441,
0.00907134, 0.01018378, 0.01619365, 0.01368207, 0.00811307,
0.00797374, 0.00203479, 0.01576032, 0.00571632, 0.01918676,
0.01943602, 0.01443881, 0.00838026, 0.01247424, 0.03940187,
0.00458872, 0.00593179, 0.00763497, 0.01569443, 0.00928753,
0.00832274, 0.01927812, 0.01300336, 0.01611656, 0.00836862,
0.02824757, 0.00749628, 0.01016575, 0.00688316, 0.01153568,
0.01382723, 0.01400221, 0.03169467, 0.00772492]),
'std_test_f1_score': array([0.04717119, 0.00988754, 0.00582427, 0.01818961, 0.01157378,
0.01147579, 0.02151356, 0.01067613, 0.11124963, 0.01744317,
0.01346508, 0.04728326, 0.02373772, 0.01721496, 0.01668381,
0.012991 , 0.0785546 , 0.01235943, 0.01414564, 0.00743852,
0.01142063, 0.04780229, 0.01863899, 0.09128707, 0.01471753,
0.01411324, 0.01785984, 0.01853197, 0.01643702, 0.08856126,
0.01189445, 0.01578804, 0.02210099, 0.02451178, 0.01082913,
0.01114802, 0.00297589, 0.02303798, 0.00960787, 0.0301833 ,
0.0280556 , 0.02015676, 0.01155343, 0.0175578 , 0.06630148,
0.08479773, 0.00823084, 0.01107662, 0.0232767 , 0.01308243,
0.01040174, 0.03427584, 0.01725145, 0.0249 , 0.01093185,
0.03988437, 0.00866125, 0.01468536, 0.01074257, 0.01616415,
0.01872196, 0.01928732, 0.05010079, 0.01164643]),
'std_test_precision_score': array([0.04687016, 0.01570788, 0.00523636, 0.01703624, 0.01111971,
0.00823544, 0.02677947, 0.00740104, 0.1255439 , 0.01488116,
0.0105455 , 0.05886661, 0.02849092, 0.01581133, 0.01574223,
0.01512872, 0.07442172, 0.00871126, 0.01068376, 0.01902706,
0.01243887, 0.0689956 , 0.015698 , 0.42659957, 0.00948347,
0.01465324, 0.01563443, 0.01703145, 0.01481839, 0.13956867,
0.01058636, 0.01892163, 0.02062128, 0.03830781, 0.01272613,
0.00786971, 0.00221329, 0.01993545, 0.01207923, 0.03826132,
0.03263501, 0.02217034, 0.01094895, 0.01661736, 0.08407117,
0.36265757, 0.00939531, 0.01467295, 0.02358114, 0.01011279,
0.01192257, 0.02466281, 0.01398993, 0.02343914, 0.00950971,
0.08530579, 0.00354874, 0.01770221, 0.00890015, 0.01462944,
0.01758559, 0.01717096, 0.11740732, 0.01678467]),
'std_test_recall_score': array([0.04743545, 0.04281237, 0.03225811, 0.02276892, 0.03165579,
0.02369813, 0.03576474, 0.02167206, 0.32428708, 0.02510898,
0.0229761 , 0.03195878, 0.02515706, 0.02814952, 0.0279394 ,
0.03588919, 0.09771083, 0.02303909, 0.02079561, 0.05103947,
0.03293471, 0.07357138, 0.01784039, 0.4827053 , 0.03620984,
0.03033102, 0.0311971 , 0.02306368, 0.02139831, 0.23370181,
0.02939845, 0.04288608, 0.01870933, 0.07753748, 0.03640516,
0.02410895, 0.01813867, 0.01532557, 0.03408898, 0.04535918,
0.02613879, 0.03271974, 0.02489195, 0.0180836 , 0.10706059,
0.45127794, 0.0349366 , 0.04030495, 0.02851246, 0.02325628,
0.04095208, 0.22730853, 0.02422558, 0.02300106, 0.03302011,
0.08899631, 0.03363278, 0.03380965, 0.01733684, 0.02701752,
0.03204261, 0.0214954 , 0.18894401, 0.04782817]),
'std_test_roc_auc_score': array([0.0273586 , 0.01126169, 0.00865766, 0.01285712, 0.01156148,
0.01164735, 0.01124306, 0.01167991, 0.07081819, 0.01430143,
0.01288372, 0.01567024, 0.01128924, 0.01413821, 0.0132648 ,
0.01060918, 0.10253815, 0.01183796, 0.01371422, 0.01120383,
0.00994204, 0.01704239, 0.01529633, 0.00327369, 0.01564249,
0.0119411 , 0.0152356 , 0.01369614, 0.01247142, 0.06920312,
0.01187038, 0.01306101, 0.01481102, 0.01438325, 0.01168122,
0.0121982 , 0.00464551, 0.01740838, 0.00690312, 0.01097509,
0.01381257, 0.01343064, 0.01000663, 0.01212272, 0.02806014,
0.0580667 , 0.00964762, 0.01106626, 0.01524513, 0.01275557,
0.01332689, 0.05971635, 0.01434785, 0.01691785, 0.01218377,
0.00814684, 0.01267562, 0.0102399 , 0.00970493, 0.01340474,
0.01516086, 0.0145127 , 0.05059578, 0.01232903]),
'std_train_accuracy_score': array([0.07414053, 0.01960899, 0.01846041, 0.01062959, 0.01402574,
0.005213 , 0.02480279, 0.00325895, 0.25114819, 0.00798608,
0.01206379, 0.045666 , 0.02255292, 0.01098098, 0.01111771,
0.02003146, 0.23318293, 0.00613487, 0.00355466, 0.02658405,
0.01775494, 0.06998099, 0.01110398, 0.37585595, 0.00928379,
0.01271191, 0.01177298, 0.00879947, 0.01206105, 0.28412487,
0.00816732, 0.02342378, 0.00779453, 0.05439058, 0.01554727,
0.00273103, 0.01208139, 0.01686221, 0.02341127, 0.03612815,
0.02336552, 0.01853328, 0.00981209, 0.00813482, 0.12807054,
0.30909148, 0.01628195, 0.02064875, 0.01824654, 0.00815635,
0.01573379, 0.21326702, 0.00723676, 0.02398273, 0.0153593 ,
0.05955848, 0.01134791, 0.01975585, 0.01747935, 0.0095533 ,
0.01033557, 0.00841133, 0.16402079, 0.02622002]),
'std_train_average_precision_score': array([0.02216068, 0.01109289, 0.00631713, 0.00668249, 0.00632497,
0.00448229, 0.01063535, 0.00195363, 0.05006613, 0.00289591,
0.00490189, 0.01871115, 0.00852476, 0.00489491, 0.0047065 ,
0.00951475, 0.04586851, 0.00418775, 0.0008835 , 0.0143588 ,
0.00835389, 0.02948636, 0.00360344, 0.00055108, 0.00714931,
0.00673208, 0.0065658 , 0.00307008, 0.00504104, 0.0493728 ,
0.00493874, 0.00949482, 0.00330773, 0.01475142, 0.0102166 ,
0.00174619, 0.00423078, 0.0055905 , 0.01060149, 0.01351599,
0.01019653, 0.00898304, 0.00522478, 0.00347089, 0.03651923,
0.00487687, 0.00870468, 0.01038485, 0.00724335, 0.00349712,
0.009262 , 0.01966145, 0.0039047 , 0.00891135, 0.00754712,
0.02575912, 0.00541056, 0.00962027, 0.00562789, 0.00420031,
0.0056777 , 0.00328622, 0.03276618, 0.01169514]),
'std_train_f1_score': array([0.03854958, 0.01683084, 0.0107881 , 0.00970718, 0.01015704,
0.00633232, 0.01738555, 0.00285433, 0.11318784, 0.00487939,
0.0079729 , 0.0304486 , 0.01425596, 0.0078285 , 0.00771523,
0.01522298, 0.08447784, 0.00607737, 0.00111977, 0.0219182 ,
0.01341444, 0.0463491 , 0.00624506, 0.0917344 , 0.00997136,
0.01035508, 0.00987011, 0.00524663, 0.00828982, 0.08314975,
0.00744562, 0.01580464, 0.00500313, 0.02642794, 0.01509578,
0.00256064, 0.00716721, 0.00978165, 0.01709684, 0.02283324,
0.01645163, 0.01421713, 0.00816973, 0.00553768, 0.06307765,
0.08372835, 0.01353392, 0.01630969, 0.01206441, 0.00562379,
0.01380028, 0.03372692, 0.00582754, 0.01506051, 0.01191576,
0.03850133, 0.00860704, 0.01523492, 0.00962813, 0.00679416,
0.00861138, 0.00543887, 0.0511655 , 0.01896381]),
'std_train_precision_score': array([0.04037899, 0.02265345, 0.01034012, 0.01251447, 0.01133242,
0.00563348, 0.02415862, 0.0023038 , 0.12921945, 0.00677923,
0.00829921, 0.04721629, 0.02322345, 0.01028791, 0.00956058,
0.01756572, 0.07895754, 0.00669241, 0.00129869, 0.03168492,
0.01445294, 0.06710284, 0.0068449 , 0.42955319, 0.01213349,
0.01371676, 0.01298756, 0.0074393 , 0.01025811, 0.13219434,
0.00879946, 0.02073306, 0.00795748, 0.04315904, 0.01795696,
0.00205518, 0.00569109, 0.01144288, 0.0177576 , 0.03134108,
0.0239649 , 0.01921878, 0.00924172, 0.00754358, 0.07813776,
0.36209415, 0.01479496, 0.02093287, 0.01557726, 0.00558452,
0.01954681, 0.028226 , 0.00702915, 0.01662407, 0.01491949,
0.0770905 , 0.01043575, 0.01908084, 0.00883701, 0.00831242,
0.01162866, 0.00744155, 0.12030757, 0.02491749]),
'std_train_recall_score': array([0.05475019, 0.01827845, 0.01013544, 0.01400531, 0.01359636,
0.00865734, 0.02823465, 0.00584311, 0.31620802, 0.01155976,
0.01207176, 0.05959822, 0.0330802 , 0.01410102, 0.01216532,
0.01583659, 0.07882739, 0.00803692, 0.00875285, 0.02846357,
0.01021923, 0.07455479, 0.01077451, 0.48332056, 0.01137982,
0.01506971, 0.0147759 , 0.01285008, 0.01289292, 0.24024522,
0.00773225, 0.02550002, 0.01642165, 0.07804232, 0.00981244,
0.00496184, 0.01258838, 0.01586494, 0.01256413, 0.04179523,
0.02940047, 0.02034889, 0.00673328, 0.01247657, 0.11098202,
0.44694631, 0.00998547, 0.01983398, 0.02052489, 0.00785264,
0.01816159, 0.20649697, 0.01108743, 0.01903708, 0.01471818,
0.08457261, 0.01119246, 0.019309 , 0.01193592, 0.01051596,
0.01238964, 0.01163935, 0.18245949, 0.02732914]),
'std_train_roc_auc_score': array([0.01922338, 0.00366266, 0.00655043, 0.00419041, 0.00482931,
0.00497517, 0.00278473, 0.0027973 , 0.0734647 , 0.00177698,
0.00437075, 0.00469031, 0.00347651, 0.00281694, 0.00261335,
0.00584427, 0.10781111, 0.00354717, 0.00210385, 0.00414488,
0.00582562, 0.01304577, 0.00262325, 0.00210818, 0.00477158,
0.003654 , 0.00367546, 0.00179495, 0.00283955, 0.06955087,
0.0032724 , 0.00427184, 0.00402894, 0.00823132, 0.00558992,
0.00248665, 0.0074496 , 0.00262473, 0.00815668, 0.00539125,
0.00370653, 0.00373565, 0.00334999, 0.00266351, 0.02948973,
0.05551219, 0.00589714, 0.00433316, 0.00286106, 0.00349435,
0.00377901, 0.05916925, 0.00340941, 0.00549393, 0.00348014,
0.01003527, 0.0031119 , 0.00501803, 0.00694062, 0.00244515,
0.00271608, 0.00190455, 0.04318998, 0.00438648])}
# Extract the mean train/validation curves for three metrics from the
# randomized-search CV results, then plot each metric in its own panel
# to compare generalization across the candidate configurations.
cv_results = mlp_best_model.cv_results_
val_acc_scores = cv_results['mean_test_accuracy_score']
train_acc_scores = cv_results['mean_train_accuracy_score']
val_roc_auc_scores = cv_results['mean_test_roc_auc_score']
train_roc_auc_scores = cv_results['mean_train_roc_auc_score']
val_f1_scores = cv_results['mean_test_f1_score']
train_f1_scores = cv_results['mean_train_f1_score']

# One subplot per metric: (validation curve, train curve, panel title).
panels = [
    (val_acc_scores, train_acc_scores, 'Mean Accuracy Score'),
    (val_roc_auc_scores, train_roc_auc_scores, 'Mean ROC-AUC Scores'),
    (val_f1_scores, train_f1_scores, 'Mean F1 Scores'),
]
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 6))
for axis, (val_curve, train_curve, title) in zip(ax, panels):
    # Thick translucent validation line + thin train line keeps both
    # visible where the curves overlap.
    axis.plot(val_curve, label='Test Score', linewidth=7, alpha=0.6)
    axis.plot(train_curve, label='Train Score', linewidth=1)
    axis.legend(loc='best')
    axis.set_title(title)
plt.show()
def mlp_learning_curve(model):
    """Plot train/validation loss per epoch with the early-stopping epoch
    marked, plus train/validation accuracy on a secondary y-axis."""
    history = model.history
    # skorch History slices already yield per-epoch value lists.
    loss = list(history[:, 'train_loss'])
    valid_loss = list(history[:, 'valid_loss'])
    val_acc = list(history[:, 'valid_acc'])
    train_acc = list(history[:, 'train_acc'])
    plt.figure(figsize=(16, 9))
    plt.plot(loss, label='Training Loss')
    lines = plt.plot(valid_loss, label="Validation Loss")
    # 1-based epoch with the lowest validation loss = early-stopping checkpoint.
    minposs = valid_loss.index(min(valid_loss)) + 1
    plt.axvline(minposs, linestyle='--', color='r',
                label='Early Stopping Checkpoint (Epoch:' + str(minposs) + ')')
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend(loc='upper left')
    # Accuracy curves share the x-axis but use their own scale.
    ax2 = plt.gca().twinx()
    ax2.plot(val_acc, label='Validation Accuracy', color='g')
    ax2.plot(train_acc, label='Train Accuracy', color='r')
    plt.ylabel('Accuracy')
    plt.legend(loc='upper center')
    plt.tight_layout()
mlp_learning_curve(nnet_selection)
# Assemble the random-search results (hyperparameters + CV metrics) in one frame.
cv_res = mlp_model_selection.cv_results_
df_RandomSearchResults_mlp = pd.concat(
    [pd.DataFrame(cv_res["params"]),
     pd.DataFrame(cv_res['mean_test_accuracy_score'], columns=['Accuracy']),
     pd.DataFrame(cv_res['mean_test_roc_auc_score'], columns=['ROC_AUC']),
     pd.DataFrame(cv_res['mean_test_average_precision_score'], columns=['PR_AUC'])],
    axis=1)
df_RandomSearchResults_mlp.sort_values(by=['ROC_AUC'], ascending=False)
| nnet__batch_size | nnet__lr | nnet__max_epochs | nnet__module__dropout | nnet__module__hidden_dim | nnet__optimizer | nnet__optimizer__weight_decay | Accuracy | ROC_AUC | PR_AUC | |
|---|---|---|---|---|---|---|---|---|---|---|
| 25 | 32 | 0.020723 | 53 | 0.1 | 44 | <class '__main__.AdaBound'> | 0.000800 | 0.794962 | 0.729234 | 0.237395 |
| 41 | 32 | 0.000745 | 177 | 0.4 | 63 | <class 'torch.optim.adam.Adam'> | 0.000668 | 0.794780 | 0.728076 | 0.236282 |
| 24 | 32 | 0.000346 | 301 | 0.1 | 75 | <class 'torch.optim.adam.Adam'> | 0.000007 | 0.800637 | 0.727963 | 0.238812 |
| 26 | 32 | 0.000173 | 180 | 0.2 | 66 | <class 'torch.optim.adam.Adam'> | 0.000011 | 0.785918 | 0.727898 | 0.232635 |
| 55 | 32 | 0.015692 | 288 | 0.0 | 38 | <class 'torch.optim.adam.Adam'> | 0.001051 | 0.788862 | 0.727444 | 0.233665 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 13 | 32 | 0.000118 | 153 | 0.9 | 54 | <class 'torch.optim.adam.Adam'> | 0.015320 | 0.748255 | 0.703154 | 0.210409 |
| 45 | 32 | 0.096460 | 121 | 0.7 | 43 | <class 'torch.optim.adam.Adam'> | 0.000033 | 0.722367 | 0.695615 | 0.219455 |
| 58 | 32 | 0.000019 | 127 | 0.8 | 44 | <class 'torch.optim.adam.Adam'> | 0.000012 | 0.656571 | 0.684652 | 0.180528 |
| 36 | 32 | 0.000016 | 57 | 0.6 | 55 | <class 'torch.optim.adam.Adam'> | 0.000819 | 0.645645 | 0.668851 | 0.171901 |
| 23 | 32 | 0.034611 | 126 | 0.9 | 62 | <class 'torch.optim.adam.Adam'> | 0.018718 | 0.590501 | 0.579030 | 0.145428 |
64 rows × 10 columns
Save hyperparameter results to CSV file (Top 20)
# Rename search-space keys to human-readable column names.
column_names = {'nnet__lr': 'Learning_Rate',
                'nnet__module__dropout': 'Dropout',
                'nnet__module__hidden_dim': 'Hidden_Nodes',
                'nnet__optimizer__weight_decay': 'Weight_Decay',
                'nnet__optimizer': 'Optimizer',
                'nnet__max_epochs': 'Max Epochs'}
df_RandomSearchResults_mlp.rename(column_names, axis=1, inplace=True)
# Persist the 20 best configurations (by ROC-AUC) for the report.
(df_RandomSearchResults_mlp
 .sort_values(by=['ROC_AUC'], ascending=False)
 .head(20)
 .to_csv('mpl_random_search_hyperparameters.csv'))
cols = ['Learning_Rate', 'Weight_Decay', 'Dropout', 'Hidden_Nodes', 'Max Epochs', 'ROC_AUC']
fig = px.parallel_coordinates(df_RandomSearchResults_mlp, color='ROC_AUC', dimensions=cols,
                              color_continuous_scale=px.colors.sequential.Viridis,
                              title="MLP Hyperparameter Search Plot (ROC_AUC)",
                              width=1000, height=700)
fig.show()
# Rename any remaining search-space keys to readable column names
# (pandas rename silently skips keys already renamed above).
rename_map = {'nnet__lr': 'Learning_Rate',
              'nnet__module__dropout': 'Dropout',
              'nnet__module__hidden_dim': 'Hidden_Nodes',
              'nnet__optimizer__weight_decay': 'Weight_Decay',
              'nnet__max_epochs': 'Max Epochs'}
df_RandomSearchResults_mlp.rename(rename_map, axis=1, inplace=True)
cols = ['Learning_Rate', 'Weight_Decay', 'Dropout', 'Hidden_Nodes', 'Max Epochs', 'Accuracy']
fig = px.parallel_coordinates(df_RandomSearchResults_mlp, color='Accuracy', dimensions=cols,
                              color_continuous_scale=px.colors.sequential.Viridis,
                              title="MLP Hyperparameter Search Plot (Accuracy)",
                              width=1000, height=700)
fig.show()
Plot table of the hyperparameters and metrics of roc_auc, pr_auc and accuracy.
y_train_preds = mlp_model_selection.predict(X_train.astype(np.float32))
Now we want to analyse how the model can be compared against the training data:
True Positives: Total number of correct predictions of a client subscription.
False Positives: The model predicted a subscription but the client did not subscribe.
True Negatives: Total number of correct predictions of clients not subscribed.
False Negative: The model predicted a non subscription but the client is subscribed.
From the definitions above, a false positive (a predicted subscription that does not happen) yields less benefit than correctly identifying outcomes, so in this business setting it is more important to keep the number of false positives low than the number of false negatives.
F1-Score: the harmonic mean of precision and recall. Support: the number of examples of each class used to compute the per-class metrics.
# Evaluate the tuned MLP against the training data and print a full report.
y_train_preds = mlp_model_selection.predict(X_train.astype(np.float32))
actual_label = y_train.astype(np.float32).squeeze(1)
# Per-metric summary of the hard-label predictions.
accuracy = accuracy_score(actual_label, y_train_preds)
roc_auc = roc_auc_score(actual_label, y_train_preds)
f1 = f1_score(actual_label, y_train_preds)
cm = confusion_matrix(actual_label, y_train_preds)
report = classification_report(actual_label, y_train_preds)
print("Model ROC-AUC(Train Data): ", roc_auc)
print("Model F1-Score (Train Data): ", f1)
print("Model Accuracy: ", accuracy)
print("Confusion Matrix:\n", cm)
print("\nClassification Report:\n", report)
Model ROC-AUC(Train Data): 0.7315446961293876
Model F1-Score (Train Data): 0.4134590874330869
Model Accuracy: 0.7905007587253414
Confusion Matrix:
[[23614 5624]
[ 1279 2433]]
Classification Report:
precision recall f1-score support
0.0 0.95 0.81 0.87 29238
1.0 0.30 0.66 0.41 3712
accuracy 0.79 32950
macro avg 0.63 0.73 0.64 32950
weighted avg 0.88 0.79 0.82 32950
# Keep probabilities for the positive outcome only.
# BUG FIX: .predict() returns 1-D hard labels, so slicing [:, 1] fails and a
# label-based ROC sweep is degenerate; use predict_proba as the comment intends.
yhat_mlp = mlp_model_selection.predict_proba(X_train.astype(np.float32))[:, 1]
# calculate roc curves
fpr, tpr, thresholds = roc_curve(y_train, yhat_mlp)
plot_roc_auc_thresholds(fpr, tpr, thresholds, y_train, yhat_mlp, plot_type='MLP')
# G-mean = sqrt(TPR * (1 - FPR)): balances sensitivity and specificity,
# so its maximum marks the best probability threshold on the ROC curve.
gmeans = sqrt(tpr * (1-fpr))
# locate the index of the largest g-mean
ix = argmax(gmeans)
print('Best Threshold=%f, G-Mean=%.3f' % (thresholds[ix], gmeans[ix]))
# plot the roc curve for the model (label fixed: this is the MLP, not logistic)
plt.plot([0,1], [0,1], linestyle='--', label='No Skill')
plt.plot(fpr, tpr, marker='.', label='MLP')
plt.scatter(fpr[ix], tpr[ix], marker='o', color='black', label='Best')
# axis labels
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend()
# show the plot
plt.show()
Best Threshold=0.592243, G-Mean=0.728
# Data to plot precision - recall curve
# NOTE(review): y_train_preds holds hard class labels (from .predict), so this
# curve has only one effective threshold; the score-based PR curve computed a
# few cells below with predict_proba is the usual form — confirm which was
# intended before relying on this PR-AUC value.
precision, recall, thresholds = precision_recall_curve(y_train, y_train_preds)
# Use AUC function to calculate the area under the curve of precision recall curve
auc_precision_recall = auc(recall, precision)
print('PR-AUC:',auc_precision_recall)
PR-AUC: 0.4981158190287896
# Average precision of the positive-class scores on the training set.
average_precision = average_precision_score(y_train.astype(np.float32), yhat_mlp)
# NOTE(review): plot_precision_recall_curve was deprecated in scikit-learn 1.0
# and removed in 1.2; PrecisionRecallDisplay.from_estimator is the replacement.
disp = plot_precision_recall_curve(mlp_best_model, X_train.astype(np.float32), y_train.astype(np.float32))
disp.ax_.set_title('2-class Precision-Recall curve: '
'AP={0:0.2f}'.format(average_precision))
Text(0.5, 1.0, '2-class Precision-Recall curve: AP=0.39')
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    """
    Plot precision and recall as functions of the decision threshold.

    Modified from:
    Hands-On Machine learning with Scikit-Learn
    and TensorFlow; p.89
    """
    plt.figure(figsize=(8, 8))
    plt.title("Precision and Recall Scores as a function of the decision threshold")
    # precision_recall_curve returns one more precision/recall point than
    # thresholds, so the last element is dropped to align the arrays.
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
    plt.ylabel("Score")
    plt.xlabel("Decision Threshold")
    plt.legend(loc='best')
plot_precision_recall_vs_threshold(precision, recall, thresholds)
# Re-derive the PR curve from class-1 probabilities and inspect threshold t=0.5.
y_pred_proba_mlp = mlp_model_selection.predict_proba(X_train.astype(np.float32))[:, 1]
precision, recall, thresholds = precision_recall_curve(y_train, y_pred_proba_mlp)
precision_recall_threshold(precision, recall, thresholds, y_train, y_pred_proba_mlp,
                           t=0.5, plot_type='MLP', savefig='No')
findfont: Font family ['Arial'] not found. Falling back to DejaVu Sans.
We will follow a strategy idealized by the Deep Learning book (Aaron Courville, Ian Goodfellow, and Yoshua Bengio):
We define a new instance of the Neural Network with the parameters found previously.
nnet_retrain
<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
module_=NeuralNet(
(fcl1): Linear(in_features=40, out_features=75, bias=True)
(fcl2): Linear(in_features=75, out_features=38, bias=True)
(output): Linear(in_features=38, out_features=2, bias=True)
(batchnorm1): BatchNorm1d(75, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(batchnorm2): BatchNorm1d(38, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(dropout): Dropout(p=0.1, inplace=False)
),
)
# Save the model when finds the best train loss
def train_mlp_final_model(X_train_model, y_train_model, model, device):
    """Retrain the MLP on the entire training set with the best hyperparameters.

    Parameters
    ----------
    X_train_model, y_train_model : full training features / labels.
    model : fitted search object (RandomizedSearchCV) whose best_params_ carry
        the winning skorch pipeline hyperparameters.
    device : torch device the rebuilt module is moved to.

    Returns
    -------
    tuple of (fitted imblearn pipeline, inner NeuralNetClassifier).
    """
    # Get best hyperparameters from model selection process
    nnet_batch_size = model.best_params_['nnet__batch_size']
    nnet_lr = model.best_params_['nnet__lr']
    nnet_module_dropout = model.best_params_['nnet__module__dropout']
    nnet_module_hidden_dim = model.best_params_['nnet__module__hidden_dim']
    nnet_optimizer_weight_decay = model.best_params_['nnet__optimizer__weight_decay']
    nnet_optimizer = model.best_params_['nnet__optimizer']
    nnet_max_epochs = model.best_params_['nnet__max_epochs']
    scaler = preprocessing.StandardScaler()
    over = SMOTE(sampling_strategy=0.2, random_state=2 ,k_neighbors=7)
    rand_under = RandomUnderSampler(sampling_strategy='majority', random_state=2)
    #monitor_losses = lambda net: all(net.history[-1, ('train_loss_best',)])
    # Rebuild the torch module with the tuned architecture hyperparameters.
    model = NeuralNet(hidden_dim=nnet_module_hidden_dim, dropout=nnet_module_dropout)
    model.to(device)
    # Define new instance to train
    n_net_retrain = NeuralNetClassifier(
        model,
        max_epochs=nnet_max_epochs,
        batch_size=nnet_batch_size,
        criterion=nn.CrossEntropyLoss,
        lr=nnet_lr,
        #callbacks=[Checkpoint(monitor=monitor_losses)], # Save best train loss
        iterator_train__shuffle=True, # Shuffle training data on each epoch
        optimizer__weight_decay=nnet_optimizer_weight_decay,
        optimizer=nnet_optimizer,
        train_split=None, # Disable skorch's internal validation split so the entire training set is used for fitting
        device=device
    )
    # Define the Imbalanced Pipeline WITH SMOTE and Random Under Sampling
    # (imblearn applies resampling only during fit, never during predict).
    nnet_pipeline_retrain = imbPipeline([('scaler',scaler),
                                         ('o',over),('ru',rand_under),
                                         ('nnet', n_net_retrain)])
    # Train the model using all training data
    mlp_final_model = nnet_pipeline_retrain.fit(X_train_model.astype(np.float32), y_train_model.astype(np.int64).squeeze(1))
    return mlp_final_model, n_net_retrain
# Retrain on the full training set and persist the fitted pipeline to disk.
mlp_final_model, n_net_retrain = train_mlp_final_model(X_train,y_train,mlp_model_selection,device)
dump(mlp_final_model, 'mlp_final_model.joblib')
epoch train_loss dur
------- ------------ ------
1 0.6172 1.1915
2 0.5801 1.1584
3 0.5681 1.1641
4 0.5626 1.1317
5 0.5627 1.1327
6 0.5625 1.1350
7 0.5600 1.1171
8 0.5582 1.1924
9 0.5583 1.2024
10 0.5574 1.1339
11 0.5590 1.1180
12 0.5573 1.1911
13 0.5573 1.1635
14 0.5565 1.1460
15 0.5561 1.1659
16 0.5550 1.1440
17 0.5550 1.1575
18 0.5556 1.1631
19 0.5544 1.1372
20 0.5557 1.1423
21 0.5501 1.1750
22 0.5531 1.1246
23 0.5523 1.1426
24 0.5513 1.1237
25 0.5519 1.1501
26 0.5519 1.1479
27 0.5513 1.1619
28 0.5526 1.1457
29 0.5535 1.1537
30 0.5510 1.1816
31 0.5513 1.1732
32 0.5510 1.1586
33 0.5497 1.1358
34 0.5524 1.1496
35 0.5498 1.1695
36 0.5518 1.1538
37 0.5498 1.1499
38 0.5488 1.1890
39 0.5522 1.1493
40 0.5498 1.1879
41 0.5505 1.1237
42 0.5497 1.2291
43 0.5477 1.1634
44 0.5496 1.1713
45 0.5499 1.1383
46 0.5498 1.1715
47 0.5524 1.1708
48 0.5522 1.1430
49 0.5475 1.1239
50 0.5493 1.1305
51 0.5499 1.1775
52 0.5472 1.1279
53 0.5487 1.1378
['mlp_final_model.joblib']
Now that we have retrained the model we can save the model that will be applied against the test set for the algorithm comparison
dump(mlp_final_model, 'mlp_best_model.joblib')
def mlp_save_best_hyperparam(model):
    """Extract the winning hyperparameters from a fitted search object and
    pickle them to 'mlp_best_hyperparam.pkl' for later model reconstruction."""
    best = model.best_params_
    dict_mlp_best_hyp = {
        'batch_size': best['nnet__batch_size'],
        'lr': best['nnet__lr'],
        'dropout': best['nnet__module__dropout'],
        'hidden_dim': best['nnet__module__hidden_dim'],
        'weight_decay': best['nnet__optimizer__weight_decay'],
        'optimizer': best['nnet__optimizer'],
        'max_epochs': best['nnet__max_epochs'],
    }
    # Saving MLP best Hyperparameters
    with open('mlp_best_hyperparam.pkl', 'wb') as f:
        pkl.dump(dict_mlp_best_hyp, f)
mlp_save_best_hyperparam(mlp_model_selection)
# Save best model trained on entire training set
# (weights, optimizer state, and epoch history are stored as separate files).
n_net_retrain.save_params(f_params='mlp_best_network.pkl', f_optimizer='mlp_opt.pkl', f_history='mlp_history.json')
nnet_retrain
<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
module_=NeuralNet(
(fcl1): Linear(in_features=40, out_features=96, bias=True)
(fcl2): Linear(in_features=96, out_features=96, bias=True)
(output): Linear(in_features=96, out_features=2, bias=True)
(batchnorm1): BatchNorm1d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(batchnorm2): BatchNorm1d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(dropout): Dropout(p=0.3, inplace=False)
),
)
# Reload the persisted hyperparameters and rebuild the network from disk.
with open('mlp_best_hyperparam.pkl', 'rb') as handle:
    mlp_best_hyperparam = pkl.load(handle)
# Define new model instance using model parameters, optimizer, and history files
new_mlp_net = NeuralNetClassifier(
    module=NeuralNet(hidden_dim=mlp_best_hyperparam['hidden_dim'],
                     dropout=mlp_best_hyperparam['dropout']),
    # BUG FIX: was mlp_best_hyperparam['dropout'], which set max_epochs to a
    # float dropout rate instead of the tuned epoch count.
    max_epochs=mlp_best_hyperparam['max_epochs'],
    batch_size=mlp_best_hyperparam['batch_size'],
    criterion=nn.CrossEntropyLoss,
    optimizer=torch.optim.Adam
)
# Initialize before loading saved parameters (required by skorch).
new_mlp_net.initialize() # This is important!
# BUG FIX: the network weights were saved above as 'mlp_best_network.pkl';
# 'mlp_best_model.pkl' does not exist ('mlp_best_model.joblib' is the
# joblib pipeline dump, not a skorch params file).
new_mlp_net.load_params(f_params='mlp_best_network.pkl', f_optimizer='mlp_opt.pkl', f_history='mlp_history.json')
By observing the learning curves, I can tell if the Neural Network overfitted or underfitted the data. Overfit : if the training loss curve is significantly lower than the validation loss curve. Underfit: if both the training loss curve and the validation loss curve are very high loss. Ideal: both the training loss and validation loss curves have a minimal gap between them and converge to a very low loss.
# (Re-)import metrics and score the tuned model on the held-out network test split.
from sklearn.metrics import confusion_matrix, accuracy_score
accuracy_score(y_test_net.astype(np.float32), gs.predict(X_test_net.astype(np.float32)))
0.8909929594561787
# Compare the MLP ROC curve against a no-skill baseline on the test split.
y_true = y_test_net.astype(np.float32).squeeze(1)
# No-skill baseline: always predict the majority (negative) class.
ns_probs = [0] * len(y_true)
# Keep probabilities for the positive outcome only.
y_scores = gs.predict_proba(X_test_net.astype(np.float32))[:, 1]
# calculate scores
ns_auc = roc_auc_score(y_true, ns_probs)
lr_auc = roc_auc_score(y_true, y_scores)
# summarize scores
print('No Skill: ROC AUC=%.3f' % (ns_auc))
print('Logistic: ROC AUC=%.3f' % (lr_auc))
# calculate roc curves
ns_fpr, ns_tpr, _ = roc_curve(y_true, ns_probs)
lr_fpr, lr_tpr, _ = roc_curve(y_true, y_scores)
# plot the roc curve for the model
plt.plot(ns_fpr, ns_tpr, linestyle='--', label='No Skill')
plt.plot(lr_fpr, lr_tpr, marker='.', label='MLP')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend()
plt.show()
No Skill: ROC AUC=0.500 Logistic: ROC AUC=0.808
def precision_recall_curve_plot(precisions, recalls, thresholds):
    """Plot precision/recall vs. decision threshold with an annotated balance point."""
    fig, ax = plt.subplots(figsize=(12,8))
    # Drop the last precision/recall point to align with the thresholds array.
    plt.plot(thresholds, precisions[:-1], "r--", label="Precisions")
    plt.plot(thresholds, recalls[:-1], "#424242", label="Recalls")
    plt.title("Precision and Recall \n Tradeoff", fontsize=18)
    plt.ylabel("Level of Precision and Recall", fontsize=16)
    plt.xlabel("Thresholds", fontsize=16)
    plt.legend(loc="best", fontsize=14)
    plt.xlim([-0.5, 1.5])
    plt.ylim([0, 1])
    # NOTE(review): the vertical marker is drawn at x=0.13 while the annotation
    # text claims the balance point is at threshold 0.6 — confirm which value
    # is intended and make the marker and text consistent.
    plt.axvline(x=0.13, linewidth=3, color="#0B3861")
    plt.annotate('Best Precision and \n Recall Balance \n is at 0.6 \n threshold ', xy=(0.6, 0.49), xytext=(55, -40),
                 textcoords="offset points",
                 arrowprops=dict(facecolor='black', shrink=0.05),
                 fontsize=12,
                 color='k')
#precision_recall_curve(lr_precision, lr_recall, lr_thresholds)
#plt.show()
Confusion Matrix of the best selected MLP model (Training with entire training data)
# Confusion matrix of the final MLP pipeline evaluated on the training data.
mlp_best_model = load('mlp_best_model.joblib')
y_pred_mlp = mlp_best_model.predict(X_train.astype(np.float32))
conf_matrix = confusion_matrix(y_train.ravel(), y_pred_mlp)
plot_confusion_matrix(conf_matrix,
                      target_names=['Not Subscribed', 'Subscribed'],
                      plot_type='MLP')
# Show the class imbalance, then oversample the minority class with SMOTE
# so both classes end up with the same number of observations.
print("Before OverSampling, counts of label '1': {}".format(sum(y==1)))
print("Before OverSampling, counts of label '0': {} \n".format(sum(y==0)))
oversample = SMOTE(random_state=2)
X_sm, Y_sm = oversample.fit_resample(X_scaled, y.ravel())
X_sm.shape, Y_sm.shape
Before OverSampling, counts of label '1': [4640] Before OverSampling, counts of label '0': [36548]
((73096, 8), (73096,))
Before applying optimization we decided to apply a linear kernel to test validation set performance using the same dataset as the one used for training the neural network. We will use the same pipeline (imblearn) and run the sklearn implementation of SVC with a linear kernel against the training set.
# SVC(gamma='auto')
# Quick check of the 'auto' gamma value: 1 / n_features (40 features here).
1/40
0.025
# Resampling: SMOTE up to a 0.3 minority ratio, then random undersampling
# of the majority class down to a 0.5 ratio.
over = SMOTE(sampling_strategy=0.3, random_state=2, k_neighbors=11)
under = RandomUnderSampler(sampling_strategy=0.5, random_state=2)
# Imbalanced-learn pipeline: scale -> oversample -> undersample -> SVC.
svm_pipeline = imbPipeline([('scaler', scaler),
                            ('o', over), ('u', under),
                            ('svc', SVC(gamma='auto'))])
svm_model = svm_pipeline
svm_model.fit(X_train_net, y_train_net)
Pipeline(steps=[('scaler', StandardScaler()),
('o',
SMOTE(k_neighbors=11, random_state=2, sampling_strategy=0.3)),
('u',
RandomUnderSampler(random_state=2, sampling_strategy=0.5)),
('svc', SVC(gamma='auto'))])
def plot_learning_curve(estimator, title, X, y, axes=None, ylim=None, cv=None,
                        n_jobs=None, train_sizes=np.linspace(.1, 1.0, 5)):
    """Plot learning curve, scalability, and performance panels for an estimator.

    Adapted from the scikit-learn learning-curve example. The three axes show:
    (0) train/CV score vs. training-set size, (1) fit time vs. training-set
    size ("scalability"), (2) CV score vs. fit time ("performance").
    Returns the matplotlib.pyplot module for further customization.
    """
    if axes is None:
        _, axes = plt.subplots(3, 1, figsize=(20, 5))
    axes[0].set_title(title)
    if ylim is not None:
        axes[0].set_ylim(*ylim)
    axes[0].set_xlabel("Training examples")
    axes[0].set_ylabel("Score")
    # learning_curve refits the estimator for every CV fold and training size.
    train_sizes, train_scores, test_scores, fit_times, _ = \
        learning_curve(estimator, X, y, cv=cv, n_jobs=n_jobs,
                       train_sizes=train_sizes,
                       return_times=True)
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)
    fit_times_mean = np.mean(fit_times, axis=1)
    fit_times_std = np.std(fit_times, axis=1)
    # Plot learning curve (mean score with +/- one standard-deviation band)
    axes[0].grid()
    axes[0].fill_between(train_sizes, train_scores_mean - train_scores_std,
                         train_scores_mean + train_scores_std, alpha=0.1,
                         color="r")
    axes[0].fill_between(train_sizes, test_scores_mean - test_scores_std,
                         test_scores_mean + test_scores_std, alpha=0.1,
                         color="g")
    axes[0].plot(train_sizes, train_scores_mean, 'o-', color="r",
                 label="Training score")
    axes[0].plot(train_sizes, test_scores_mean, 'o-', color="g",
                 label="Cross-validation score")
    axes[0].legend(loc="best")
    # Plot n_samples vs fit_times
    axes[1].grid()
    axes[1].plot(train_sizes, fit_times_mean, 'o-')
    axes[1].fill_between(train_sizes, fit_times_mean - fit_times_std,
                         fit_times_mean + fit_times_std, alpha=0.1)
    axes[1].set_xlabel("Training examples")
    axes[1].set_ylabel("fit_times")
    axes[1].set_title("Scalability of the model")
    # Plot fit_time vs score
    axes[2].grid()
    axes[2].plot(fit_times_mean, test_scores_mean, 'o-')
    axes[2].fill_between(fit_times_mean, test_scores_mean - test_scores_std,
                         test_scores_mean + test_scores_std, alpha=0.1)
    axes[2].set_xlabel("fit_times")
    axes[2].set_ylabel("Score")
    axes[2].set_title("Performance of the model")
    return plt
# Plot learning curves for the linear-kernel SVM pipeline with default C.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
title = r"Learning Curves of SVM (Linear kernel) with default parameters: C = 1 (default), $\gamma=0.0025$)"
# SVC is expensive so we can perform a lower number of CV iterations of 5
svm_linear_split = 5
cv = StratifiedKFold(n_splits=svm_linear_split, shuffle=True, random_state=2)
starttime = time.time()
plot_learning_curve(svm_model, title, X_train_net, y_train_net, axes=axes, ylim=(0.7, 1.01),
                    cv=cv, n_jobs=4)
totaltime = time.time() - starttime
# BUG FIX: the original message mixed %-formatting with a stray '{}' that was
# printed literally ("... of K={}"); report the fold count once.
print("SVM (Linear kernel) took %.2f seconds for %d CV folds." % (totaltime, svm_linear_split))
plt.show()
SVM (Linear kernel) took 369.04 seconds for 5 CV folds of K={}
# fit a SVM model to the data
# BUG FIX: the original created `svm_model = LinearSVC()` but then fitted,
# scored, and predicted with a stale variable named `model`, so the new
# LinearSVC was never actually trained or evaluated.
svm_model = LinearSVC()
start = time.time()
svm_model.fit(X_sm, Y_sm.ravel())
lsvm_time = time.time() - start
lsvm_score = 100 * svm_model.score(X_sm, Y_sm)
# Perform predictions
expected = Y_sm
predicted = svm_model.predict(X_sm)
# Summarize the fit of the model
print(metrics.classification_report(expected, predicted))
print(metrics.confusion_matrix(expected, predicted))
print(f"Linear SVM score on raw features: {lsvm_score:.2f}%")
print('Linear SVM time on raw features:',lsvm_time)
C:\Users\user\AppData\Local\Programs\Python\Python38\lib\site-packages\sklearn\svm\_base.py:985: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
warnings.warn("Liblinear failed to converge, increase "
precision recall f1-score support
0 0.71 0.73 0.72 36548
1 0.72 0.71 0.72 36548
accuracy 0.72 73096
macro avg 0.72 0.72 0.72 73096
weighted avg 0.72 0.72 0.72 73096
[[26558 9990]
[10629 25919]]
Linear SVM score on raw features: 71.79%
Linear SVM time on raw features: 14.083805799484253
# fit a SVM model with an RBF kernel to the data
svm_model = SVC(kernel='rbf')
start = time.time()
svm_model.fit(X_sm, Y_sm.ravel())
lsvm_time = time.time() - start
# BUG FIX: score/predict previously used the stale `model` variable, so the
# metrics below described an earlier estimator, not this RBF SVC.
lsvm_score = 100 * svm_model.score(X_sm, Y_sm)
# Perform predictions
expected = Y_sm
predicted = svm_model.predict(X_sm)
# Summarize the fit of the model
print(metrics.classification_report(expected, predicted))
We will use the same pipeline as the one used before and the only difference would be the hyperparameters and the heuristic of the model. Due to the large size of the training set the SVM algorithm can take a long time to complete the hyperparameter search. The fit time scales at least quadratically with the number of samples and may be impractical beyond 10000 samples according to studies.
Due to the large size of the training set (32950 observations) we opted to optimise using a stratified sample of the training set defined to 70%, which will maintain the same proportion of classes. After that we apply Standardisation, SMOTE and UnderSampling (only runs for the training set) enclosed in a pipeline to run a RandomizedSearch cross-validation using SVM with several parameter configurations for the 3 kernel initially selected (linear, poly, rbf) and C and gamma intervals as well.
Then we will run the same configuration using RandomizedSearch cross-validation and compare if the results outperforms GridSearch.
At last we will select the best parameters and train again using the entire training set and use the model object to predict the results against the original test set previously defined.
# Stratified sub-sample of the training set to keep SVM tuning tractable
# (test_size=0.75 keeps 25% of the training rows for the search).
X_train_svm, X_test_svm, y_train_svm, y_test_svm = train_test_split(X_train,
                                                                    y_train,
                                                                    test_size=0.75,
                                                                    random_state=42,
                                                                    stratify=y_train)
# NOTE(review): the messages say "50% stratified sample" but the split above
# retains 25% of the data (and the prose mentions 70%) — confirm the intended
# fraction and align the text with the actual test_size.
print('Minority class percentage in 50% stratified sample',sum(y_train_svm==1)/y_train_svm.shape[0])
print('Minority class observations in 50% stratified sample',X_train_svm.shape)
Minority class percentage in 50% stratified sample [0.11266238] Minority class observations in 50% stratified sample (8237, 40)
# svm_rs = GridSearchCV(svm_smp_pipeline, param_grid, cv = 5, scoring = "accuracy", n_jobs = -1, verbose = 42)
# svm_nonlinear = {'C': [0.0001, 0.00000001], #, 0.00001, 0.0001],
# 'gamma': [0.01], #1, 0.1, 0.01, 0.001,'auto'],
# 'kernel': ['rbf']} #'poly', 'rbf', 'sigmoid']}
# param_grid = {'clf__kernel': svm_nonlinear['kernel'], 'clf__C': svm_nonlinear['C'], 'clf__gamma': svm_nonlinear['gamma']}
# #[svm_nonlinear] #,svm_linear
# t0= time()
# grid = GridSearchCV(svm_smp_pipeline, param_grid, cv = 3, scoring = "accuracy", n_jobs = 4, verbose = 5)
# clf = grid.fit(X_train_svm, y_train_svm.ravel())
# print("done in %0.3fs" % (time() - t0))
# print("Best estimator found by grid search:")
# print(clf.best_estimator_)
Hyperparameters will be set in the following interval:
def train_svm_best_model(X_train_model, y_train_model, n_iter = 10):
    """Randomised hyperparameter search for an SVM inside a resampling pipeline.

    Parameters
    ----------
    X_train_model, y_train_model : training features / labels to search on.
    n_iter : number of random parameter settings sampled (default 10).

    Returns
    -------
    The fitted RandomizedSearchCV object (refit on 'roc_auc_score').
    """
    scaler = preprocessing.StandardScaler()
    # Define resampling technique: SMOTE to a 0.2 minority ratio, then
    # random undersampling of the majority class.
    over = SMOTE(sampling_strategy=0.2, k_neighbors=7, random_state=2)
    rand_under = RandomUnderSampler(sampling_strategy='majority', random_state=2)
    # Class weights inversely proportional to class frequencies.
    weights = 'balanced'
    # Create a Imbalance Pipeline with Over Sampling and Under Sampling
    # (imblearn applies resampling to training folds only).
    svm_smp_pipeline = imbPipeline([('scaler',scaler),
                                    ('o', over), ('ru', rand_under),
                                    ('svm', SVC(class_weight=weights,random_state=2,probability=False))])
    # Hyperparameter space: log-uniform C/gamma for rbf, log-uniform C for linear.
    svm_params = [{'svm__kernel': ['rbf'], 'svm__gamma': loguniform(1e-4, 1e-3), 'svm__C': loguniform(1e0, 1e3)},
                  {'svm__kernel': ['linear'], 'svm__C': loguniform(1e0, 1e3)}]
    scorersSVM = {
        'precision_score': make_scorer(precision_score, zero_division=1),
        'recall_score': make_scorer(recall_score, zero_division=1),
        'accuracy_score': make_scorer(accuracy_score),
        'roc_auc_score': make_scorer(roc_auc_score,average='weighted'),
        'average_precision_score': make_scorer(average_precision_score),
        'f1_score': make_scorer(f1_score)
    }
    skf = StratifiedKFold(n_splits=3,shuffle=True, random_state=2)
    svm_rs = RandomizedSearchCV(svm_smp_pipeline, svm_params, refit='roc_auc_score', cv=skf, scoring=scorersSVM,
                                n_iter=n_iter, random_state=123, verbose=3, return_train_score=True)
    starttime = time.time()
    svm_model_selection = svm_rs.fit(X_train_model, y_train_model.ravel())
    totaltime = time.time() - starttime
    print("SVM (RandomSearch) took %.2f seconds for %d CV folds." % ((totaltime), 3))
    # BUG FIX: the summary previously printed the label "best score" twice;
    # the first value is actually the best parameter set.
    print("best params: {}, best score: {}".format(svm_model_selection.best_params_, svm_model_selection.best_score_))
    return svm_model_selection
# Run the randomized search (32 candidates) on the stratified SVM sub-sample.
svm_model_selection = train_svm_best_model(X_train_svm, y_train_svm, n_iter = 32)
# (Optional) persist the tuned model for reuse:
#svm_best_model = train_svm_best_model(X_train,y_train,svm_model_selection)
# with open('svm_best_model.pkl', 'wb') as fid:
# pkl.dump(svm_best_model, fid)
# files.download('svm_best_model.pkl')
Fitting 3 folds for each of 32 candidates, totalling 96 fits
[CV 1/3] END svm__C=137.67844795855254, svm__gamma=0.00026820750502750263, svm__kernel=rbf; accuracy_score: (train=0.731, test=0.726) average_precision_score: (train=0.208, test=0.205) f1_score: (train=0.368, test=0.363) precision_score: (train=0.250, test=0.246) recall_score: (train=0.695, test=0.695) roc_auc_score: (train=0.715, test=0.712) total time= 2.0s
[CV 2/3] END svm__C=137.67844795855254, svm__gamma=0.00026820750502750263, svm__kernel=rbf; accuracy_score: (train=0.737, test=0.738) average_precision_score: (train=0.210, test=0.214) f1_score: (train=0.371, test=0.377) precision_score: (train=0.254, test=0.257) recall_score: (train=0.688, test=0.701) roc_auc_score: (train=0.716, test=0.722) total time= 2.2s
[CV 3/3] END svm__C=137.67844795855254, svm__gamma=0.00026820750502750263, svm__kernel=rbf; accuracy_score: (train=0.731, test=0.732) average_precision_score: (train=0.210, test=0.204) f1_score: (train=0.371, test=0.362) precision_score: (train=0.252, test=0.247) recall_score: (train=0.702, test=0.678) roc_auc_score: (train=0.719, test=0.708) total time= 2.0s
[CV 1/3] END svm__C=45.07588967315673, svm__gamma=0.0005241661481516257, svm__kernel=rbf; accuracy_score: (train=0.739, test=0.733) average_precision_score: (train=0.212, test=0.208) f1_score: (train=0.373, test=0.368) precision_score: (train=0.256, test=0.251) recall_score: (train=0.692, test=0.688) roc_auc_score: (train=0.718, test=0.713) total time= 2.0s
[CV 2/3] END svm__C=45.07588967315673, svm__gamma=0.0005241661481516257, svm__kernel=rbf; accuracy_score: (train=0.748, test=0.747) average_precision_score: (train=0.214, test=0.218) f1_score: (train=0.378, test=0.382) precision_score: (train=0.262, test=0.264) recall_score: (train=0.681, test=0.695) roc_auc_score: (train=0.719, test=0.724) total time= 2.0s
[CV 3/3] END svm__C=45.07588967315673, svm__gamma=0.0005241661481516257, svm__kernel=rbf; accuracy_score: (train=0.739, test=0.738) average_precision_score: (train=0.214, test=0.208) f1_score: (train=0.376, test=0.368) precision_score: (train=0.258, test=0.253) recall_score: (train=0.700, test=0.678) roc_auc_score: (train=0.722, test=0.712) total time= 2.0s
[CV 1/3] END svm__C=218.81812166859152, svm__kernel=linear; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 3.5min
[CV 2/3] END svm__C=218.81812166859152, svm__kernel=linear; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 7.0min
[CV 3/3] END svm__C=218.81812166859152, svm__kernel=linear; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 4.2min
[CV 1/3] END svm__C=113.36766867014445, svm__kernel=linear; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 1.9min
[CV 2/3] END svm__C=113.36766867014445, svm__kernel=linear; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 3.9min
[CV 3/3] END svm__C=113.36766867014445, svm__kernel=linear; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 2.2min
[CV 1/3] END svm__C=2.6293735380546712, svm__gamma=0.0002517778708684341, svm__kernel=rbf; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 1.9s
[CV 2/3] END svm__C=2.6293735380546712, svm__gamma=0.0002517778708684341, svm__kernel=rbf; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 2.0s
[CV 3/3] END svm__C=2.6293735380546712, svm__gamma=0.0002517778708684341, svm__kernel=rbf; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 1.9s
[CV 1/3] END svm__C=153.86828816779857, svm__gamma=0.00027451889663713023, svm__kernel=rbf; accuracy_score: (train=0.737, test=0.732) average_precision_score: (train=0.211, test=0.207) f1_score: (train=0.373, test=0.367) precision_score: (train=0.255, test=0.250) recall_score: (train=0.693, test=0.690) roc_auc_score: (train=0.718, test=0.713) total time= 2.1s
[CV 2/3] END svm__C=153.86828816779857, svm__gamma=0.00027451889663713023, svm__kernel=rbf; accuracy_score: (train=0.747, test=0.747) average_precision_score: (train=0.214, test=0.218) f1_score: (train=0.378, test=0.382) precision_score: (train=0.261, test=0.264) recall_score: (train=0.681, test=0.695) roc_auc_score: (train=0.718, test=0.724) total time= 2.1s
[CV 3/3] END svm__C=153.86828816779857, svm__gamma=0.00027451889663713023, svm__kernel=rbf; accuracy_score: (train=0.737, test=0.737) average_precision_score: (train=0.213, test=0.207) f1_score: (train=0.375, test=0.368) precision_score: (train=0.256, test=0.252) recall_score: (train=0.700, test=0.678) roc_auc_score: (train=0.721, test=0.711) total time= 2.1s
[CV 1/3] END svm__C=121.41309541725505, svm__kernel=linear; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 2.1min
[CV 2/3] END svm__C=121.41309541725505, svm__kernel=linear; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 3.9min
[CV 3/3] END svm__C=121.41309541725505, svm__kernel=linear; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 2.4min
[CV 1/3] END svm__C=163.6764576114005, svm__gamma=0.0001522270146549397, svm__kernel=rbf; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.205, test=0.203) f1_score: (train=0.362, test=0.359) precision_score: (train=0.244, test=0.241) recall_score: (train=0.702, test=0.704) roc_auc_score: (train=0.713, test=0.711) total time= 2.0s
[CV 2/3] END svm__C=163.6764576114005, svm__gamma=0.0001522270146549397, svm__kernel=rbf; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.209) f1_score: (train=0.358, test=0.367) precision_score: (train=0.241, test=0.247) recall_score: (train=0.698, test=0.714) roc_auc_score: (train=0.709, test=0.719) total time= 2.1s
[CV 3/3] END svm__C=163.6764576114005, svm__gamma=0.0001522270146549397, svm__kernel=rbf; accuracy_score: (train=0.720, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.710, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 2.0s
[CV 1/3] END svm__C=1.7831154117152481, svm__gamma=0.0005156205067883839, svm__kernel=rbf; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.202) f1_score: (train=0.361, test=0.358) precision_score: (train=0.243, test=0.240) recall_score: (train=0.701, test=0.703) roc_auc_score: (train=0.712, test=0.710) total time= 1.9s
[CV 2/3] END svm__C=1.7831154117152481, svm__gamma=0.0005156205067883839, svm__kernel=rbf; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.357, test=0.366) precision_score: (train=0.240, test=0.247) recall_score: (train=0.696, test=0.712) roc_auc_score: (train=0.708, test=0.718) total time= 2.0s
[CV 3/3] END svm__C=1.7831154117152481, svm__gamma=0.0005156205067883839, svm__kernel=rbf; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 1.9s
[CV 1/3] END svm__C=80.02079756292471, svm__kernel=linear; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 1.3min
[CV 2/3] END svm__C=80.02079756292471, svm__kernel=linear; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 3.0min
[CV 3/3] END svm__C=80.02079756292471, svm__kernel=linear; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 1.6min
[CV 1/3] END svm__C=29.940174718450645, svm__gamma=0.0005499160442140829, svm__kernel=rbf; accuracy_score: (train=0.726, test=0.722) average_precision_score: (train=0.207, test=0.205) f1_score: (train=0.366, test=0.362) precision_score: (train=0.247, test=0.244) recall_score: (train=0.701, test=0.701) roc_auc_score: (train=0.715, test=0.713) total time= 1.9s
[CV 2/3] END svm__C=29.940174718450645, svm__gamma=0.0005499160442140829, svm__kernel=rbf; accuracy_score: (train=0.729, test=0.733) average_precision_score: (train=0.206, test=0.211) f1_score: (train=0.364, test=0.372) precision_score: (train=0.248, test=0.253) recall_score: (train=0.690, test=0.703) roc_auc_score: (train=0.712, test=0.720) total time= 2.0s
[CV 3/3] END svm__C=29.940174718450645, svm__gamma=0.0005499160442140829, svm__kernel=rbf; accuracy_score: (train=0.728, test=0.728) average_precision_score: (train=0.210, test=0.203) f1_score: (train=0.369, test=0.361) precision_score: (train=0.250, test=0.245) recall_score: (train=0.707, test=0.683) roc_auc_score: (train=0.719, test=0.708) total time= 2.0s
[CV 1/3] END svm__C=147.00433699563095, svm__gamma=0.00021035794225904132, svm__kernel=rbf; accuracy_score: (train=0.722, test=0.718) average_precision_score: (train=0.205, test=0.204) f1_score: (train=0.362, test=0.360) precision_score: (train=0.244, test=0.242) recall_score: (train=0.702, test=0.704) roc_auc_score: (train=0.713, test=0.712) total time= 2.0s
[CV 2/3] END svm__C=147.00433699563095, svm__gamma=0.00021035794225904132, svm__kernel=rbf; accuracy_score: (train=0.720, test=0.724) average_precision_score: (train=0.203, test=0.209) f1_score: (train=0.359, test=0.368) precision_score: (train=0.242, test=0.248) recall_score: (train=0.698, test=0.714) roc_auc_score: (train=0.710, test=0.720) total time= 2.1s
[CV 3/3] END svm__C=147.00433699563095, svm__gamma=0.00021035794225904132, svm__kernel=rbf; accuracy_score: (train=0.722, test=0.721) average_precision_score: (train=0.207, test=0.201) f1_score: (train=0.365, test=0.357) precision_score: (train=0.246, test=0.241) recall_score: (train=0.710, test=0.689) roc_auc_score: (train=0.716, test=0.707) total time= 2.0s
[CV 1/3] END svm__C=13.225261877758934, svm__gamma=0.0001715982581734538, svm__kernel=rbf; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 1.9s
[CV 2/3] END svm__C=13.225261877758934, svm__gamma=0.0001715982581734538, svm__kernel=rbf; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 1.9s
[CV 3/3] END svm__C=13.225261877758934, svm__gamma=0.0001715982581734538, svm__kernel=rbf; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 1.9s
[CV 1/3] END svm__C=78.14989008992033, svm__gamma=0.0001236246115404309, svm__kernel=rbf; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 1.9s
[CV 2/3] END svm__C=78.14989008992033, svm__gamma=0.0001236246115404309, svm__kernel=rbf; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 2.0s
[CV 3/3] END svm__C=78.14989008992033, svm__gamma=0.0001236246115404309, svm__kernel=rbf; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 1.9s
[CV 1/3] END svm__C=1.21323761079222, svm__gamma=0.00014924731137151486, svm__kernel=rbf; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 2.1s
[CV 2/3] END svm__C=1.21323761079222, svm__gamma=0.00014924731137151486, svm__kernel=rbf; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 2.3s
[CV 3/3] END svm__C=1.21323761079222, svm__gamma=0.00014924731137151486, svm__kernel=rbf; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 2.1s
[CV 1/3] END svm__C=18.94483690832992, svm__gamma=0.000205239629963043, svm__kernel=rbf; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 1.8s
[CV 2/3] END svm__C=18.94483690832992, svm__gamma=0.000205239629963043, svm__kernel=rbf; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 1.9s
[CV 3/3] END svm__C=18.94483690832992, svm__gamma=0.000205239629963043, svm__kernel=rbf; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 1.9s
[CV 1/3] END svm__C=178.293747881023, svm__kernel=linear; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 2.9min
[CV 2/3] END svm__C=178.293747881023, svm__kernel=linear; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 6.4min
[CV 3/3] END svm__C=178.293747881023, svm__kernel=linear; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 3.3min
[CV 1/3] END svm__C=679.9548167126571, svm__gamma=0.00031756795756268224, svm__kernel=rbf; accuracy_score: (train=0.803, test=0.792) average_precision_score: (train=0.250, test=0.222) f1_score: (train=0.432, test=0.396) precision_score: (train=0.320, test=0.294) recall_score: (train=0.665, test=0.606) roc_auc_score: (train=0.743, test=0.710) total time= 2.3s
[CV 2/3] END svm__C=679.9548167126571, svm__gamma=0.00031756795756268224, svm__kernel=rbf; accuracy_score: (train=0.783, test=0.781) average_precision_score: (train=0.234, test=0.233) f1_score: (train=0.409, test=0.407) precision_score: (train=0.295, test=0.293) recall_score: (train=0.665, test=0.666) roc_auc_score: (train=0.732, test=0.731) total time= 2.4s
[CV 3/3] END svm__C=679.9548167126571, svm__gamma=0.00031756795756268224, svm__kernel=rbf; accuracy_score: (train=0.786, test=0.783) average_precision_score: (train=0.240, test=0.227) f1_score: (train=0.416, test=0.400) precision_score: (train=0.301, test=0.291) recall_score: (train=0.676, test=0.644) roc_auc_score: (train=0.738, test=0.722) total time= 2.4s
[CV 1/3] END svm__C=50.733179220704606, svm__kernel=linear; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 56.0s
[CV 2/3] END svm__C=50.733179220704606, svm__kernel=linear; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 1.8min
[CV 3/3] END svm__C=50.733179220704606, svm__kernel=linear; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 1.1min
[CV 1/3] END svm__C=8.95068854731697, svm__gamma=0.0002599119286878712, svm__kernel=rbf; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 1.9s
[CV 2/3] END svm__C=8.95068854731697, svm__gamma=0.0002599119286878712, svm__kernel=rbf; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 1.9s
[CV 3/3] END svm__C=8.95068854731697, svm__gamma=0.0002599119286878712, svm__kernel=rbf; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 1.9s
[CV 1/3] END svm__C=562.4656709460996, svm__kernel=linear; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 7.9min
[CV 2/3] END svm__C=562.4656709460996, svm__kernel=linear; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time=14.1min
[CV 3/3] END svm__C=562.4656709460996, svm__kernel=linear; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 9.6min
[CV 1/3] END svm__C=28.125664555627026, svm__kernel=linear; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 32.9s
[CV 2/3] END svm__C=28.125664555627026, svm__kernel=linear; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 1.1min
[CV 3/3] END svm__C=28.125664555627026, svm__kernel=linear; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 33.0s
[CV 1/3] END svm__C=4.994681391083983, svm__kernel=linear; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 7.7s
[CV 2/3] END svm__C=4.994681391083983, svm__kernel=linear; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 12.3s
[CV 3/3] END svm__C=4.994681391083983, svm__kernel=linear; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 8.4s
[CV 1/3] END svm__C=68.97370850315686, svm__gamma=0.0001320166372062253, svm__kernel=rbf; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 1.9s
[CV 2/3] END svm__C=68.97370850315686, svm__gamma=0.0001320166372062253, svm__kernel=rbf; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 2.1s
[CV 3/3] END svm__C=68.97370850315686, svm__gamma=0.0001320166372062253, svm__kernel=rbf; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 1.9s
[CV 1/3] END svm__C=247.2778560720801, svm__gamma=0.0001917356783593693, svm__kernel=rbf; accuracy_score: (train=0.728, test=0.724) average_precision_score: (train=0.207, test=0.206) f1_score: (train=0.366, test=0.364) precision_score: (train=0.248, test=0.246) recall_score: (train=0.699, test=0.701) roc_auc_score: (train=0.715, test=0.714) total time= 2.0s
[CV 2/3] END svm__C=247.2778560720801, svm__gamma=0.0001917356783593693, svm__kernel=rbf; accuracy_score: (train=0.733, test=0.735) average_precision_score: (train=0.207, test=0.212) f1_score: (train=0.367, test=0.373) precision_score: (train=0.250, test=0.254) recall_score: (train=0.689, test=0.701) roc_auc_score: (train=0.713, test=0.720) total time= 2.2s
[CV 3/3] END svm__C=247.2778560720801, svm__gamma=0.0001917356783593693, svm__kernel=rbf; accuracy_score: (train=0.729, test=0.730) average_precision_score: (train=0.210, test=0.203) f1_score: (train=0.369, test=0.361) precision_score: (train=0.250, test=0.246) recall_score: (train=0.704, test=0.680) roc_auc_score: (train=0.718, test=0.708) total time= 2.1s
[CV 1/3] END svm__C=10.67312267901468, svm__kernel=linear; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 13.6s
[CV 2/3] END svm__C=10.67312267901468, svm__kernel=linear; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 23.6s
[CV 3/3] END svm__C=10.67312267901468, svm__kernel=linear; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 14.9s
[CV 1/3] END svm__C=10.529019888421411, svm__gamma=0.0004889585069858854, svm__kernel=rbf; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.205, test=0.203) f1_score: (train=0.362, test=0.359) precision_score: (train=0.244, test=0.241) recall_score: (train=0.702, test=0.704) roc_auc_score: (train=0.713, test=0.711) total time= 1.9s
[CV 2/3] END svm__C=10.529019888421411, svm__gamma=0.0004889585069858854, svm__kernel=rbf; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.357, test=0.366) precision_score: (train=0.240, test=0.247) recall_score: (train=0.696, test=0.712) roc_auc_score: (train=0.708, test=0.718) total time= 1.9s
[CV 3/3] END svm__C=10.529019888421411, svm__gamma=0.0004889585069858854, svm__kernel=rbf; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 1.9s
[CV 1/3] END svm__C=423.029374725911, svm__kernel=linear; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 6.0min
[CV 2/3] END svm__C=423.029374725911, svm__kernel=linear; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time=12.4min
[CV 3/3] END svm__C=423.029374725911, svm__kernel=linear; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 7.4min
[CV 1/3] END svm__C=23.45544898118272, svm__kernel=linear; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 25.8s
[CV 2/3] END svm__C=23.45544898118272, svm__kernel=linear; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 1.0min
[CV 3/3] END svm__C=23.45544898118272, svm__kernel=linear; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 30.6s
[CV 1/3] END svm__C=57.25450413061507, svm__gamma=0.00042160281513632935, svm__kernel=rbf; accuracy_score: (train=0.731, test=0.726) average_precision_score: (train=0.208, test=0.205) f1_score: (train=0.368, test=0.363) precision_score: (train=0.250, test=0.246) recall_score: (train=0.694, test=0.693) roc_auc_score: (train=0.715, test=0.712) total time= 2.0s
[CV 2/3] END svm__C=57.25450413061507, svm__gamma=0.00042160281513632935, svm__kernel=rbf; accuracy_score: (train=0.737, test=0.738) average_precision_score: (train=0.210, test=0.214) f1_score: (train=0.371, test=0.376) precision_score: (train=0.254, test=0.257) recall_score: (train=0.689, test=0.701) roc_auc_score: (train=0.716, test=0.722) total time= 2.0s
[CV 3/3] END svm__C=57.25450413061507, svm__gamma=0.00042160281513632935, svm__kernel=rbf; accuracy_score: (train=0.732, test=0.732) average_precision_score: (train=0.211, test=0.205) f1_score: (train=0.371, test=0.363) precision_score: (train=0.252, test=0.248) recall_score: (train=0.702, test=0.680) roc_auc_score: (train=0.719, test=0.709) total time= 2.0s
[CV 1/3] END svm__C=38.921183136758046, svm__gamma=0.00014975201017109607, svm__kernel=rbf; accuracy_score: (train=0.721, test=0.716) average_precision_score: (train=0.204, test=0.203) f1_score: (train=0.361, test=0.359) precision_score: (train=0.243, test=0.241) recall_score: (train=0.701, test=0.704) roc_auc_score: (train=0.712, test=0.711) total time= 1.9s
[CV 2/3] END svm__C=38.921183136758046, svm__gamma=0.00014975201017109607, svm__kernel=rbf; accuracy_score: (train=0.718, test=0.722) average_precision_score: (train=0.202, test=0.208) f1_score: (train=0.358, test=0.366) precision_score: (train=0.241, test=0.247) recall_score: (train=0.697, test=0.712) roc_auc_score: (train=0.709, test=0.718) total time= 2.0s
[CV 3/3] END svm__C=38.921183136758046, svm__gamma=0.00014975201017109607, svm__kernel=rbf; accuracy_score: (train=0.719, test=0.720) average_precision_score: (train=0.206, test=0.200) f1_score: (train=0.363, test=0.356) precision_score: (train=0.244, test=0.240) recall_score: (train=0.708, test=0.689) roc_auc_score: (train=0.715, test=0.706) total time= 1.9s
[CV 1/3] END svm__C=195.45578290561392, svm__gamma=0.0001752533686433448, svm__kernel=rbf; accuracy_score: (train=0.722, test=0.718) average_precision_score: (train=0.205, test=0.204) f1_score: (train=0.362, test=0.360) precision_score: (train=0.244, test=0.242) recall_score: (train=0.702, test=0.704) roc_auc_score: (train=0.713, test=0.712) total time= 2.1s
[CV 2/3] END svm__C=195.45578290561392, svm__gamma=0.0001752533686433448, svm__kernel=rbf; accuracy_score: (train=0.719, test=0.723) average_precision_score: (train=0.202, test=0.209) f1_score: (train=0.359, test=0.368) precision_score: (train=0.241, test=0.247) recall_score: (train=0.698, test=0.714) roc_auc_score: (train=0.710, test=0.719) total time= 2.2s
[CV 3/3] END svm__C=195.45578290561392, svm__gamma=0.0001752533686433448, svm__kernel=rbf; accuracy_score: (train=0.721, test=0.720) average_precision_score: (train=0.207, test=0.201) f1_score: (train=0.365, test=0.356) precision_score: (train=0.245, test=0.240) recall_score: (train=0.710, test=0.689) roc_auc_score: (train=0.716, test=0.707) total time= 2.1s
SVM (RandomSearch) took 7252.51 seconds for 3 CV folds.
best score: {'svm__C': 679.9548167126571, 'svm__gamma': 0.00031756795756268224, 'svm__kernel': 'rbf'}, best score: 0.7211790836233827
def plot_confusion_matrix(cf_matrix, target_names=None):
    """Plot a 2x2 binary confusion matrix as an annotated seaborn heatmap.

    Each cell is annotated with its role (TN/FP/FN/TP), its raw count and
    its share of all samples; precision, recall, accuracy and F1 derived
    from the matrix are shown under the x-axis label.

    Parameters
    ----------
    cf_matrix : array-like of shape (2, 2)
        Confusion matrix in scikit-learn layout: rows are true labels,
        columns are predicted labels, i.e. [[TN, FP], [FN, TP]].
    target_names : list of str, optional
        Class names used as tick labels on both axes.
    """
    cf_matrix = np.asarray(cf_matrix)
    group_names = ['TN', 'FP', 'FN', 'TP']
    group_counts = ["{0:.0f}".format(value) for value in
                    cf_matrix.flatten()]
    group_percentages = ["{0:.2%}".format(value) for value in
                         cf_matrix.flatten() / np.sum(cf_matrix)]
    # Compose the per-cell annotation: name, count, percentage.
    labels = [f"{v1}\n{v2}\n{v3}" for v1, v2, v3 in
              zip(group_names,
                  group_counts,
                  group_percentages)
              ]
    labels = np.asarray(labels).reshape(2, 2)
    plt.figure(figsize=(14, 8))
    sns.heatmap(cf_matrix, annot=labels, fmt='', cmap='Blues', annot_kws={"size": 14})
    if target_names:
        tick_marks = range(len(target_names))
        plt.xticks(tick_marks, target_names)
        plt.yticks(tick_marks, target_names)
    # Derive the standard binary metrics from the matrix, guarding against
    # division by zero on degenerate matrices (no predicted/actual positives).
    predicted_pos = cf_matrix[:, 1].sum()
    actual_pos = cf_matrix[1, :].sum()
    precision = cf_matrix[1, 1] / predicted_pos if predicted_pos else 0.0
    recall = cf_matrix[1, 1] / actual_pos if actual_pos else 0.0
    accuracy = np.trace(cf_matrix) / float(np.sum(cf_matrix))
    f1_score = (2 * precision * recall / (precision + recall)
                if (precision + recall) else 0.0)
    # Fixed label typos from the original string ("Precision.=" and the
    # missing "=" after "Recall").
    stats_text = "\nPrecision={:0.3f}\nRecall={:0.3f}\n\nAccuracy={:0.3f}\nF1 Score={:0.3f}".format(
        precision, recall, accuracy, f1_score)
    plt.xlabel('Predicted label {}'.format(stats_text))
    plt.ylabel("True Label")
    plt.show()
# Predict with the tuned SVM and plot the resulting confusion matrix.
# NOTE(review): this evaluates on the TRAINING split (X_train/y_train),
# not a held-out test set, so the metrics shown are optimistic — confirm
# this is intentional (e.g. to inspect fit) or switch to X_test/y_test.
y_pred_svm = svm_best_model.predict(X_train)
conf_matrix = confusion_matrix(y_train.ravel(), y_pred_svm)
plot_confusion_matrix(conf_matrix, target_names= ['Not Subscribed to term deposit', 'Subscribed to term deposit'])
# Bare expression: display the full RandomizedSearchCV results dict
# (fit/score times, per-split and mean scores, ranks, sampled params)
# as notebook cell output.
svm_model_selection.cv_results_
{'mean_fit_time': array([ 1.05025466, 0.96830765, 293.83060439, 158.87335817,
0.86085598, 1.05450463, 167.30085182, 1.01373204,
0.85539714, 117.52681748, 0.9350187 , 1.01454131,
0.86615554, 0.92007136, 0.90458592, 0.86722787,
250.06133938, 1.36155486, 76.95310855, 0.86241674,
630.86465716, 43.52959538, 9.06285739, 0.91291944,
1.08985551, 16.97333662, 0.87658985, 516.29614147,
38.50622892, 0.97705626, 0.89802027, 1.03991922]),
'mean_score_time': array([1.03323754, 1.04526631, 0.39837607, 0.39350565, 1.08787688,
1.06155554, 0.40133023, 1.00534272, 1.07570481, 0.39426335,
1.03642384, 1.01602705, 1.02935918, 1.00253352, 1.28196875,
1.01361783, 0.39660406, 0.98563433, 0.4006474 , 1.05305847,
0.39563513, 0.39488816, 0.39566429, 1.04066682, 1.01246182,
0.39719017, 1.0222377 , 0.39907424, 0.39618707, 1.01317207,
1.02336796, 1.08561317]),
'mean_test_accuracy_score': array([0.7318361 , 0.73942329, 0.71945372, 0.71945372, 0.71945372,
0.73863422, 0.71945372, 0.71939303, 0.71939303, 0.71945372,
0.72728378, 0.72091045, 0.71945372, 0.71945372, 0.71945372,
0.71945372, 0.71945372, 0.78525024, 0.71945372, 0.71945372,
0.71945372, 0.71945372, 0.71945372, 0.71945372, 0.7294082 ,
0.71945372, 0.71939303, 0.71945372, 0.71945372, 0.7320182 ,
0.71945372, 0.72018208]),
'mean_test_average_precision_score': array([0.20780125, 0.2109832 , 0.20387235, 0.20387235, 0.20387235,
0.21067981, 0.20387235, 0.20398062, 0.20370503, 0.20387235,
0.20639993, 0.20477873, 0.20387235, 0.20387235, 0.20387235,
0.20387235, 0.20387235, 0.22745232, 0.20387235, 0.20387235,
0.20387235, 0.20387235, 0.20387235, 0.20387235, 0.20713767,
0.20387235, 0.20384112, 0.20387235, 0.20387235, 0.20789813,
0.20387235, 0.20439386]),
'mean_test_f1_score': array([0.36745284, 0.37271341, 0.360538 , 0.360538 , 0.360538 ,
0.37220152, 0.360538 , 0.36066409, 0.36031271, 0.360538 ,
0.364972 , 0.36191563, 0.360538 , 0.360538 , 0.360538 ,
0.360538 , 0.360538 , 0.40103145, 0.360538 , 0.360538 ,
0.360538 , 0.360538 , 0.360538 , 0.360538 , 0.36623921,
0.360538 , 0.36048876, 0.360538 , 0.360538 , 0.36760964,
0.360538 , 0.36131245]),
'mean_test_precision_score': array([0.250251 , 0.25575117, 0.24255452, 0.24255452, 0.24255452,
0.25520052, 0.24255452, 0.24260454, 0.24241464, 0.24255452,
0.24740416, 0.24373771, 0.24255452, 0.24255452, 0.24255452,
0.24255452, 0.24255452, 0.29247599, 0.24255452, 0.24255452,
0.24255452, 0.24255452, 0.24255452, 0.24255452, 0.2487781 ,
0.24255452, 0.24251018, 0.24255452, 0.24255452, 0.25039384,
0.24255452, 0.24319062]),
'mean_test_recall_score': array([0.6912644 , 0.68695638, 0.70204056, 0.70204056, 0.70204056,
0.68749488, 0.70204056, 0.70257906, 0.70150206, 0.70204056,
0.69557504, 0.70257906, 0.70204056, 0.70204056, 0.70204056,
0.70204056, 0.70204056, 0.63847281, 0.70204056, 0.70204056,
0.70204056, 0.70204056, 0.70204056, 0.70204056, 0.69395779,
0.70204056, 0.70204056, 0.70204056, 0.70204056, 0.69126527,
0.70204056, 0.70257906]),
'mean_test_roc_auc_score': array([0.71412526, 0.71652012, 0.71185207, 0.71185207, 0.71185207,
0.71631054, 0.71185207, 0.71205292, 0.71158282, 0.71185207,
0.71344181, 0.71290797, 0.71185207, 0.71185207, 0.71185207,
0.71185207, 0.71185207, 0.72117908, 0.71185207, 0.71185207,
0.71185207, 0.71185207, 0.71185207, 0.71185207, 0.71393286,
0.71185207, 0.71181787, 0.71185207, 0.71185207, 0.7142283 ,
0.71185207, 0.71249755]),
'mean_train_accuracy_score': array([0.73308048, 0.74176033, 0.71945372, 0.71945372, 0.71945372,
0.74048569, 0.71945372, 0.71960546, 0.71948407, 0.71945372,
0.72764795, 0.7210015 , 0.71945372, 0.71945372, 0.71945372,
0.71945372, 0.71945372, 0.79098648, 0.71945372, 0.71948407,
0.71945372, 0.71945372, 0.71945372, 0.71945372, 0.72968136,
0.71945372, 0.71948407, 0.71945372, 0.71945372, 0.73341432,
0.71945372, 0.72057662]),
'mean_train_average_precision_score': array([0.20942853, 0.21330007, 0.20385588, 0.20385588, 0.20385588,
0.21269677, 0.20385588, 0.20421057, 0.20380377, 0.20385588,
0.20746971, 0.20494492, 0.20385588, 0.20385588, 0.20385588,
0.20385588, 0.20385588, 0.24147306, 0.20385588, 0.20387147,
0.20385588, 0.20385588, 0.20385588, 0.20385588, 0.20816523,
0.20385588, 0.20387271, 0.20385588, 0.20385588, 0.20961678,
0.20385588, 0.20472112]),
'mean_train_f1_score': array([0.36976625, 0.37605661, 0.36054129, 0.36054129, 0.36054129,
0.37509109, 0.36054129, 0.36101928, 0.36047752, 0.36054129,
0.36641367, 0.36217119, 0.36054129, 0.36054129, 0.36054129,
0.36054129, 0.36054129, 0.41910385, 0.36054129, 0.36056602,
0.36054129, 0.36054129, 0.36054129, 0.36054129, 0.36762249,
0.36054129, 0.36056622, 0.36054129, 0.36054129, 0.37005787,
0.36054129, 0.36182016]),
'mean_train_precision_score': array([0.25189645, 0.25838565, 0.24255417, 0.24255417, 0.24255417,
0.25740261, 0.24255417, 0.24285823, 0.24252866, 0.24255417,
0.24827445, 0.24390145, 0.24255417, 0.24255417, 0.24255417,
0.24255417, 0.24255417, 0.30534175, 0.24255417, 0.24257654,
0.24255417, 0.24255417, 0.24255417, 0.24255417, 0.2495958 ,
0.24255417, 0.24257666, 0.24255417, 0.24255417, 0.25216605,
0.24255417, 0.24358335]),
'mean_train_recall_score': array([0.69504125, 0.69073039, 0.7020457 , 0.7020457 , 0.7020457 ,
0.69126933, 0.7020457 , 0.70312314, 0.70177623, 0.7020457 ,
0.69908198, 0.70312314, 0.7020457 , 0.7020457 , 0.7020457 ,
0.7020457 , 0.7020457 , 0.66864023, 0.7020457 , 0.7020457 ,
0.7020457 , 0.7020457 , 0.7020457 , 0.7020457 , 0.69746581,
0.7020457 , 0.7020457 , 0.7020457 , 0.7020457 , 0.69504125,
0.7020457 , 0.70312314]),
'mean_train_roc_auc_score': array([0.71647541, 0.71948449, 0.71185465, 0.71185465, 0.71185465,
0.71900152, 0.71185465, 0.71241047, 0.71175411, 0.71185465,
0.71517817, 0.71319711, 0.71185465, 0.71185465, 0.71185465,
0.71185465, 0.71185465, 0.73757957, 0.71185465, 0.71187175,
0.71185465, 0.71185465, 0.71185465, 0.71185465, 0.71561847,
0.71185465, 0.71187175, 0.71185465, 0.71185465, 0.71666352,
0.71185465, 0.7129577 ]),
'param_svm__C': masked_array(data=[137.67844795855254, 45.07588967315673,
218.81812166859152, 113.36766867014445,
2.6293735380546712, 153.86828816779857,
121.41309541725505, 163.6764576114005,
1.7831154117152481, 80.02079756292471,
29.940174718450645, 147.00433699563095,
13.225261877758934, 78.14989008992033,
1.21323761079222, 18.94483690832992, 178.293747881023,
679.9548167126571, 50.733179220704606,
8.95068854731697, 562.4656709460996,
28.125664555627026, 4.994681391083983,
68.97370850315686, 247.2778560720801,
10.67312267901468, 10.529019888421411,
423.029374725911, 23.45544898118272, 57.25450413061507,
38.921183136758046, 195.45578290561392],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False],
fill_value='?',
dtype=object),
'param_svm__gamma': masked_array(data=[0.00026820750502750263, 0.0005241661481516257, --, --,
0.0002517778708684341, 0.00027451889663713023, --,
0.0001522270146549397, 0.0005156205067883839, --,
0.0005499160442140829, 0.00021035794225904132,
0.0001715982581734538, 0.0001236246115404309,
0.00014924731137151486, 0.000205239629963043, --,
0.00031756795756268224, --, 0.0002599119286878712, --,
--, --, 0.0001320166372062253, 0.0001917356783593693,
--, 0.0004889585069858854, --, --,
0.00042160281513632935, 0.00014975201017109607,
0.0001752533686433448],
mask=[False, False, True, True, False, False, True, False,
False, True, False, False, False, False, False, False,
True, False, True, False, True, True, True, False,
False, True, False, True, True, False, False, False],
fill_value='?',
dtype=object),
'param_svm__kernel': masked_array(data=['rbf', 'rbf', 'linear', 'linear', 'rbf', 'rbf',
'linear', 'rbf', 'rbf', 'linear', 'rbf', 'rbf', 'rbf',
'rbf', 'rbf', 'rbf', 'linear', 'rbf', 'linear', 'rbf',
'linear', 'linear', 'linear', 'rbf', 'rbf', 'linear',
'rbf', 'linear', 'linear', 'rbf', 'rbf', 'rbf'],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False],
fill_value='?',
dtype=object),
'params': [{'svm__C': 137.67844795855254,
'svm__gamma': 0.00026820750502750263,
'svm__kernel': 'rbf'},
{'svm__C': 45.07588967315673,
'svm__gamma': 0.0005241661481516257,
'svm__kernel': 'rbf'},
{'svm__C': 218.81812166859152, 'svm__kernel': 'linear'},
{'svm__C': 113.36766867014445, 'svm__kernel': 'linear'},
{'svm__C': 2.6293735380546712,
'svm__gamma': 0.0002517778708684341,
'svm__kernel': 'rbf'},
{'svm__C': 153.86828816779857,
'svm__gamma': 0.00027451889663713023,
'svm__kernel': 'rbf'},
{'svm__C': 121.41309541725505, 'svm__kernel': 'linear'},
{'svm__C': 163.6764576114005,
'svm__gamma': 0.0001522270146549397,
'svm__kernel': 'rbf'},
{'svm__C': 1.7831154117152481,
'svm__gamma': 0.0005156205067883839,
'svm__kernel': 'rbf'},
{'svm__C': 80.02079756292471, 'svm__kernel': 'linear'},
{'svm__C': 29.940174718450645,
'svm__gamma': 0.0005499160442140829,
'svm__kernel': 'rbf'},
{'svm__C': 147.00433699563095,
'svm__gamma': 0.00021035794225904132,
'svm__kernel': 'rbf'},
{'svm__C': 13.225261877758934,
'svm__gamma': 0.0001715982581734538,
'svm__kernel': 'rbf'},
{'svm__C': 78.14989008992033,
'svm__gamma': 0.0001236246115404309,
'svm__kernel': 'rbf'},
{'svm__C': 1.21323761079222,
'svm__gamma': 0.00014924731137151486,
'svm__kernel': 'rbf'},
{'svm__C': 18.94483690832992,
'svm__gamma': 0.000205239629963043,
'svm__kernel': 'rbf'},
{'svm__C': 178.293747881023, 'svm__kernel': 'linear'},
{'svm__C': 679.9548167126571,
'svm__gamma': 0.00031756795756268224,
'svm__kernel': 'rbf'},
{'svm__C': 50.733179220704606, 'svm__kernel': 'linear'},
{'svm__C': 8.95068854731697,
'svm__gamma': 0.0002599119286878712,
'svm__kernel': 'rbf'},
{'svm__C': 562.4656709460996, 'svm__kernel': 'linear'},
{'svm__C': 28.125664555627026, 'svm__kernel': 'linear'},
{'svm__C': 4.994681391083983, 'svm__kernel': 'linear'},
{'svm__C': 68.97370850315686,
'svm__gamma': 0.0001320166372062253,
'svm__kernel': 'rbf'},
{'svm__C': 247.2778560720801,
'svm__gamma': 0.0001917356783593693,
'svm__kernel': 'rbf'},
{'svm__C': 10.67312267901468, 'svm__kernel': 'linear'},
{'svm__C': 10.529019888421411,
'svm__gamma': 0.0004889585069858854,
'svm__kernel': 'rbf'},
{'svm__C': 423.029374725911, 'svm__kernel': 'linear'},
{'svm__C': 23.45544898118272, 'svm__kernel': 'linear'},
{'svm__C': 57.25450413061507,
'svm__gamma': 0.00042160281513632935,
'svm__kernel': 'rbf'},
{'svm__C': 38.921183136758046,
'svm__gamma': 0.00014975201017109607,
'svm__kernel': 'rbf'},
{'svm__C': 195.45578290561392,
'svm__gamma': 0.0001752533686433448,
'svm__kernel': 'rbf'}],
'rank_test_accuracy_score': array([ 5, 2, 10, 10, 10, 3, 10, 30, 30, 10, 7, 8, 10, 10, 10, 10, 10,
1, 10, 10, 10, 10, 10, 10, 6, 10, 30, 10, 10, 4, 10, 9],
dtype=int32),
'rank_test_average_precision_score': array([ 5, 2, 11, 11, 11, 3, 11, 10, 32, 11, 7, 8, 11, 11, 11, 11, 11,
1, 11, 11, 11, 11, 11, 11, 6, 11, 31, 11, 11, 4, 11, 9],
dtype=int32),
'rank_test_f1_score': array([ 5, 2, 11, 11, 11, 3, 11, 10, 32, 11, 7, 8, 11, 11, 11, 11, 11,
1, 11, 11, 11, 11, 11, 11, 6, 11, 31, 11, 11, 4, 11, 9],
dtype=int32),
'rank_test_precision_score': array([ 5, 2, 11, 11, 11, 3, 11, 10, 32, 11, 7, 8, 11, 11, 11, 11, 11,
1, 11, 11, 11, 11, 11, 11, 6, 11, 31, 11, 11, 4, 11, 9],
dtype=int32),
'rank_test_recall_score': array([29, 31, 4, 4, 4, 30, 4, 1, 25, 4, 26, 1, 4, 4, 4, 4, 4,
32, 4, 4, 4, 4, 4, 4, 27, 4, 4, 4, 4, 28, 4, 1],
dtype=int32),
'rank_test_roc_auc_score': array([ 5, 2, 11, 11, 11, 3, 11, 10, 32, 11, 7, 8, 11, 11, 11, 11, 11,
1, 11, 11, 11, 11, 11, 11, 6, 11, 31, 11, 11, 4, 11, 9],
dtype=int32),
'split0_test_accuracy_score': array([0.72560087, 0.73306628, 0.71649672, 0.71649672, 0.71649672,
0.73160961, 0.71649672, 0.71631464, 0.71631464, 0.71649672,
0.72159505, 0.71813547, 0.71649672, 0.71649672, 0.71649672,
0.71649672, 0.71649672, 0.79151493, 0.71649672, 0.71649672,
0.71649672, 0.71649672, 0.71649672, 0.71649672, 0.72396213,
0.71649672, 0.71631464, 0.71649672, 0.71649672, 0.72596504,
0.71649672, 0.71758922]),
'split0_test_average_precision_score': array([0.20529904, 0.20770005, 0.20299068, 0.20299068, 0.20299068,
0.20731461, 0.20299068, 0.202897 , 0.20248872, 0.20299068,
0.20482754, 0.20383856, 0.20299068, 0.20299068, 0.20299068,
0.20299068, 0.20299068, 0.22246975, 0.20299068, 0.20299068,
0.20299068, 0.20299068, 0.20299068, 0.20299068, 0.20608808,
0.20299068, 0.202897 , 0.20299068, 0.20299068, 0.20507735,
0.20299068, 0.203555 ]),
'split0_test_f1_score': array([0.36332911, 0.36755824, 0.35899547, 0.35899547, 0.35899547,
0.36683849, 0.35899547, 0.35884774, 0.3583196 , 0.35899547,
0.36211932, 0.36033058, 0.35899547, 0.35899547, 0.35899547,
0.35899547, 0.35899547, 0.39577836, 0.35899547, 0.35899547,
0.35899547, 0.35899547, 0.35899547, 0.35899547, 0.36409396,
0.35899547, 0.35884774, 0.35899547, 0.35899547, 0.36309776,
0.35899547, 0.35988444]),
'split0_test_precision_score': array([0.24599542, 0.25073573, 0.24088398, 0.24088398, 0.24088398,
0.24985372, 0.24088398, 0.24075097, 0.24046434, 0.24088398,
0.24409449, 0.24208773, 0.24088398, 0.24088398, 0.24088398,
0.24088398, 0.24088398, 0.29388715, 0.24088398, 0.24088398,
0.24088398, 0.24088398, 0.24088398, 0.24088398, 0.24589235,
0.24088398, 0.24075097, 0.24088398, 0.24088398, 0.24598624,
0.24088398, 0.24168514]),
'split0_test_recall_score': array([0.69466882, 0.68820679, 0.70436187, 0.70436187, 0.70436187,
0.68982229, 0.70436187, 0.70436187, 0.70274637, 0.70436187,
0.70113086, 0.70436187, 0.70436187, 0.70436187, 0.70436187,
0.70436187, 0.70436187, 0.60581583, 0.70436187, 0.70436187,
0.70436187, 0.70436187, 0.70436187, 0.70436187, 0.70113086,
0.70436187, 0.70436187, 0.70436187, 0.70436187, 0.69305331,
0.70436187, 0.70436187]),
'split0_test_roc_auc_score': array([0.71209944, 0.7134857 , 0.71120002, 0.71120002, 0.71120002,
0.71337 , 0.71120002, 0.71109742, 0.71039227, 0.71120002,
0.7126627 , 0.71212348, 0.71120002, 0.71120002, 0.71120002,
0.71120002, 0.71120002, 0.71045973, 0.71120002, 0.71120002,
0.71120002, 0.71120002, 0.71120002, 0.71120002, 0.71399658,
0.71120002, 0.71109742, 0.71120002, 0.71120002, 0.71159951,
0.71120002, 0.71181566]),
'split0_train_accuracy_score': array([0.73076573, 0.73850496, 0.72093235, 0.72093235, 0.72093235,
0.73723026, 0.72093235, 0.7210234 , 0.7210234 , 0.72093235,
0.72612219, 0.7219339 , 0.72093235, 0.72093235, 0.72093235,
0.72093235, 0.72093235, 0.80315032, 0.72093235, 0.72093235,
0.72093235, 0.72093235, 0.72093235, 0.72093235, 0.72757899,
0.72093235, 0.7210234 , 0.72093235, 0.72093235, 0.73131203,
0.72093235, 0.7217518 ]),
'split0_train_average_precision_score': array([0.20813338, 0.21166913, 0.20428668, 0.20428668, 0.20428668,
0.21114785, 0.20428668, 0.2045414 , 0.20433459, 0.20428668,
0.20706103, 0.20502253, 0.20428668, 0.20428668, 0.20428668,
0.20428668, 0.20428668, 0.25040847, 0.20428668, 0.20428668,
0.20428668, 0.20428668, 0.20428668, 0.20428668, 0.20743456,
0.20428668, 0.2045414 , 0.20428668, 0.20428668, 0.20822466,
0.20428668, 0.20492609]),
'split0_train_f1_score': array([0.36775711, 0.37347295, 0.36132528, 0.36132528, 0.36132528,
0.3726087 , 0.36132528, 0.36166667, 0.36140058, 0.36132528,
0.36566849, 0.36242171, 0.36132528, 0.36132528, 0.36132528,
0.36132528, 0.36132528, 0.43194955, 0.36132528, 0.36132528,
0.36132528, 0.36132528, 0.36132528, 0.36132528, 0.36637018,
0.36132528, 0.36166667, 0.36132528, 0.36132528, 0.36795888,
0.36132528, 0.36227045]),
'split0_train_precision_score': array([0.25 , 0.25575142, 0.24340258, 0.24340258, 0.24340258,
0.254832 , 0.24340258, 0.24361493, 0.24347094, 0.24340258,
0.24736091, 0.24430059, 0.24340258, 0.24340258, 0.24340258,
0.24340258, 0.24340258, 0.31996886, 0.24340258, 0.24340258,
0.24340258, 0.24340258, 0.24340258, 0.24340258, 0.2482066 ,
0.24340258, 0.24361493, 0.24340258, 0.24340258, 0.25029138,
0.24340258, 0.24416315]),
'split0_train_recall_score': array([0.6952304 , 0.69199677, 0.70088925, 0.70088925, 0.70088925,
0.69280517, 0.70088925, 0.70169766, 0.70088925, 0.70088925,
0.70088925, 0.70169766, 0.70088925, 0.70088925, 0.70088925,
0.70088925, 0.70088925, 0.66451091, 0.70088925, 0.70088925,
0.70088925, 0.70088925, 0.70088925, 0.70088925, 0.69927243,
0.70088925, 0.70169766, 0.70088925, 0.70088925, 0.69442199,
0.70088925, 0.70169766]),
'split0_train_roc_auc_score': array([0.7152532 , 0.71820236, 0.71218277, 0.71218277, 0.71218277,
0.71783702, 0.71218277, 0.71258698, 0.71223408, 0.71218277,
0.71510705, 0.71310001, 0.71218277, 0.71218277, 0.71218277,
0.71218277, 0.71218277, 0.74262894, 0.71218277, 0.71218277,
0.71218277, 0.71218277, 0.71218277, 0.71218277, 0.7152221 ,
0.71218277, 0.71258698, 0.71218277, 0.71218277, 0.71520812,
0.71218277, 0.7129974 ]),
'split1_test_accuracy_score': array([0.73834669, 0.74708667, 0.72232338, 0.72232338, 0.72232338,
0.74690459, 0.72232338, 0.72232338, 0.72232338, 0.72232338,
0.73270211, 0.72396213, 0.72232338, 0.72232338, 0.72232338,
0.72232338, 0.72232338, 0.78131828, 0.72232338, 0.72232338,
0.72232338, 0.72232338, 0.72232338, 0.72232338, 0.73470503,
0.72232338, 0.72232338, 0.72232338, 0.72232338, 0.7381646 ,
0.72232338, 0.72305171]),
'split1_test_average_precision_score': array([0.21416626, 0.21766988, 0.2081299 , 0.2081299 , 0.2081299 ,
0.21755753, 0.2081299 , 0.2085484 , 0.2081299 , 0.2081299 ,
0.21133613, 0.20943941, 0.2081299 , 0.2081299 , 0.2081299 ,
0.2081299 , 0.2081299 , 0.23272884, 0.2081299 , 0.2081299 ,
0.2081299 , 0.2081299 , 0.2081299 , 0.2081299 , 0.21205042,
0.2081299 , 0.2081299 , 0.2081299 , 0.2081299 , 0.21405927,
0.2081299 , 0.2089433 ]),
'split1_test_f1_score': array([0.37657267, 0.38239217, 0.36643124, 0.36643124, 0.36643124,
0.38222222, 0.36643124, 0.36695724, 0.36643124, 0.36643124,
0.37211292, 0.36833333, 0.36643124, 0.36643124, 0.36643124,
0.36643124, 0.36643124, 0.40691358, 0.36643124, 0.36643124,
0.36643124, 0.36643124, 0.36643124, 0.36643124, 0.37333333,
0.36643124, 0.36643124, 0.36643124, 0.36643124, 0.37640937,
0.36643124, 0.36756757]),
'split1_test_precision_score': array([0.257414 , 0.26380368, 0.2466443 , 0.2466443 , 0.2466443 ,
0.26364194, 0.2466443 , 0.24692737, 0.2466443 , 0.2466443 ,
0.2530541 , 0.24817518, 0.2466443 , 0.2466443 , 0.2466443 ,
0.2466443 , 0.2466443 , 0.29302987, 0.2466443 , 0.2466443 ,
0.2466443 , 0.2466443 , 0.2466443 , 0.2466443 , 0.25439625,
0.2466443 , 0.2466443 , 0.2466443 , 0.2466443 , 0.25726141,
0.2466443 , 0.2474804 ]),
'split1_test_recall_score': array([0.70113086, 0.69466882, 0.71243942, 0.71243942, 0.71243942,
0.69466882, 0.71243942, 0.71405493, 0.71243942, 0.71243942,
0.70274637, 0.71405493, 0.71243942, 0.71243942, 0.71243942,
0.71243942, 0.71243942, 0.66558966, 0.71243942, 0.71243942,
0.71243942, 0.71243942, 0.71243942, 0.71243942, 0.70113086,
0.71243942, 0.71243942, 0.71243942, 0.71243942, 0.70113086,
0.71243942, 0.71405493]),
'split1_test_roc_auc_score': array([0.72210247, 0.72420697, 0.71800916, 0.71800916, 0.71800916,
0.72410437, 0.71800916, 0.71871431, 0.71800916, 0.71800916,
0.71962683, 0.71963777, 0.71800916, 0.71800916, 0.71800916,
0.71800916, 0.71800916, 0.73080427, 0.71800916, 0.71800916,
0.71800916, 0.71800916, 0.71800916, 0.71800916, 0.72005034,
0.71800916, 0.71800916, 0.71800916, 0.71800916, 0.72199986,
0.71800916, 0.71912473]),
'split1_train_accuracy_score': array([0.73713922, 0.74797414, 0.71801876, 0.71801876, 0.71801876,
0.74715469, 0.71801876, 0.71810981, 0.71801876, 0.71801876,
0.72903578, 0.7195666 , 0.71801876, 0.71801876, 0.71801876,
0.71801876, 0.71801876, 0.78339252, 0.71801876, 0.71810981,
0.71801876, 0.71801876, 0.71801876, 0.71801876, 0.73258672,
0.71801876, 0.71801876, 0.71801876, 0.71801876, 0.73732131,
0.71801876, 0.7190203 ]),
'split1_train_average_precision_score': array([0.20980199, 0.21423272, 0.20174476, 0.20174476, 0.20174476,
0.21395606, 0.20174476, 0.20199595, 0.20154054, 0.20174476,
0.20570464, 0.20274884, 0.20174476, 0.20174476, 0.20174476,
0.20174476, 0.20174476, 0.23409268, 0.20174476, 0.20179154,
0.20174476, 0.20174476, 0.20174476, 0.20174476, 0.2074473 ,
0.20174476, 0.20154054, 0.20174476, 0.20174476, 0.21012165,
0.20174476, 0.20246572]),
'split1_train_f1_score': array([0.3708869 , 0.37825696, 0.35760216, 0.35760216, 0.35760216,
0.3777728 , 0.35760216, 0.35794276, 0.35733555, 0.35760216,
0.3643742 , 0.35913442, 0.35760216, 0.35760216, 0.35760216,
0.35760216, 0.35760216, 0.4089441 , 0.35760216, 0.35767635,
0.35760216, 0.35760216, 0.35760216, 0.35760216, 0.36716225,
0.35760216, 0.35733555, 0.35760216, 0.35760216, 0.37132273,
0.35760216, 0.35868662]),
'split1_train_precision_score': array([0.25387828, 0.26189736, 0.24051339, 0.24051339, 0.24051339,
0.26131432, 0.24051339, 0.24072524, 0.24036851, 0.24051339,
0.24760522, 0.24180443, 0.24051339, 0.24051339, 0.24051339,
0.24051339, 0.24051339, 0.29519369, 0.24051339, 0.24058052,
0.24051339, 0.24051339, 0.24051339, 0.24051339, 0.25029377,
0.24051339, 0.24036851, 0.24051339, 0.24051339, 0.25417661,
0.24051339, 0.2413986 ]),
'split1_train_recall_score': array([0.68795473, 0.68067906, 0.69684721, 0.69684721, 0.69684721,
0.68148747, 0.69684721, 0.69765562, 0.6960388 , 0.69684721,
0.68957154, 0.69765562, 0.69684721, 0.69684721, 0.69684721,
0.69684721, 0.69684721, 0.66531932, 0.69684721, 0.69684721,
0.69684721, 0.69684721, 0.69684721, 0.69684721, 0.68876314,
0.69684721, 0.6960388 , 0.69684721, 0.69684721, 0.68876314,
0.69684721, 0.69765562]),
'split1_train_roc_auc_score': array([0.71566831, 0.71859728, 0.70877657, 0.70877657, 0.70877657,
0.71848845, 0.70877657, 0.70918077, 0.70842367, 0.70877657,
0.71180814, 0.71000162, 0.70877657, 0.70877657, 0.70877657,
0.70877657, 0.70877657, 0.73184907, 0.70877657, 0.70882787,
0.70877657, 0.70877657, 0.70877657, 0.70877657, 0.71345606,
0.70877657, 0.70842367, 0.70877657, 0.70877657, 0.71612382,
0.70877657, 0.70969381]),
'split2_test_accuracy_score': array([0.73156074, 0.73811692, 0.71954107, 0.71954107, 0.71954107,
0.73738845, 0.71954107, 0.71954107, 0.71954107, 0.71954107,
0.72755418, 0.72063376, 0.71954107, 0.71954107, 0.71954107,
0.71954107, 0.71954107, 0.7829175 , 0.71954107, 0.71954107,
0.71954107, 0.71954107, 0.71954107, 0.71954107, 0.72955746,
0.71954107, 0.71954107, 0.71954107, 0.71954107, 0.73192497,
0.71954107, 0.7199053 ]),
'split2_test_average_precision_score': array([0.20393846, 0.20757966, 0.20049646, 0.20049646, 0.20049646,
0.20716729, 0.20049646, 0.20049646, 0.20049646, 0.20049646,
0.20303613, 0.20105821, 0.20049646, 0.20049646, 0.20049646,
0.20049646, 0.20049646, 0.22715837, 0.20049646, 0.20049646,
0.20049646, 0.20049646, 0.20049646, 0.20049646, 0.20327453,
0.20049646, 0.20049646, 0.20049646, 0.20049646, 0.20455776,
0.20049646, 0.20068329]),
'split2_test_f1_score': array([0.36245675, 0.36818981, 0.35618729, 0.35618729, 0.35618729,
0.36754386, 0.35618729, 0.35618729, 0.35618729, 0.35618729,
0.36068376, 0.35708298, 0.35618729, 0.35618729, 0.35618729,
0.35618729, 0.35618729, 0.40040241, 0.35618729, 0.35618729,
0.35618729, 0.35618729, 0.35618729, 0.35618729, 0.36129032,
0.35618729, 0.35618729, 0.35618729, 0.35618729, 0.3633218 ,
0.35618729, 0.35648536]),
'split2_test_precision_score': array([0.24734357, 0.25271411, 0.24013529, 0.24013529, 0.24013529,
0.2521059 , 0.24013529, 0.24013529, 0.24013529, 0.24013529,
0.24506388, 0.24095023, 0.24013529, 0.24013529, 0.24013529,
0.24013529, 0.24013529, 0.29051095, 0.24013529, 0.24013529,
0.24013529, 0.24013529, 0.24013529, 0.24013529, 0.24604569,
0.24013529, 0.24013529, 0.24013529, 0.24013529, 0.24793388,
0.24013529, 0.24040632]),
'split2_test_recall_score': array([0.67799353, 0.67799353, 0.68932039, 0.68932039, 0.68932039,
0.67799353, 0.68932039, 0.68932039, 0.68932039, 0.68932039,
0.6828479 , 0.68932039, 0.68932039, 0.68932039, 0.68932039,
0.68932039, 0.68932039, 0.64401294, 0.68932039, 0.68932039,
0.68932039, 0.68932039, 0.68932039, 0.68932039, 0.67961165,
0.68932039, 0.68932039, 0.68932039, 0.68932039, 0.67961165,
0.68932039, 0.68932039]),
'split2_test_roc_auc_score': array([0.70817386, 0.71186769, 0.70634704, 0.70634704, 0.70634704,
0.71145726, 0.70634704, 0.70634704, 0.70634704, 0.70634704,
0.70803589, 0.70696268, 0.70634704, 0.70634704, 0.70634704,
0.70634704, 0.70634704, 0.72227325, 0.70634704, 0.70634704,
0.70634704, 0.70634704, 0.70634704, 0.70634704, 0.70775165,
0.70634704, 0.70634704, 0.70634704, 0.70634704, 0.70908553,
0.70634704, 0.70655225]),
'split2_train_accuracy_score': array([0.73133649, 0.73880189, 0.71941005, 0.71941005, 0.71941005,
0.7370721 , 0.71941005, 0.71968318, 0.71941005, 0.71941005,
0.72778587, 0.72150401, 0.71941005, 0.71941005, 0.71941005,
0.71941005, 0.71941005, 0.78641661, 0.71941005, 0.71941005,
0.71941005, 0.71941005, 0.71941005, 0.71941005, 0.72887837,
0.71941005, 0.71941005, 0.71941005, 0.71941005, 0.73160961,
0.71941005, 0.72095776]),
'split2_train_average_precision_score': array([0.21035023, 0.21399837, 0.20553618, 0.20553618, 0.20553618,
0.21298639, 0.20553618, 0.20609435, 0.20553618, 0.20553618,
0.20964344, 0.2070634 , 0.20553618, 0.20553618, 0.20553618,
0.20553618, 0.20553618, 0.23991804, 0.20553618, 0.20553618,
0.20553618, 0.20553618, 0.20553618, 0.20553618, 0.20961383,
0.20553618, 0.20553618, 0.20553618, 0.20553618, 0.21050402,
0.20553618, 0.20677155]),
'split2_train_f1_score': array([0.37065472, 0.3764399 , 0.36269644, 0.36269644, 0.36269644,
0.37489177, 0.36269644, 0.36344842, 0.36269644, 0.36269644,
0.36919831, 0.36495744, 0.36269644, 0.36269644, 0.36269644,
0.36269644, 0.36269644, 0.41641791, 0.36269644, 0.36269644,
0.36269644, 0.36269644, 0.36269644, 0.36269644, 0.36933503,
0.36269644, 0.36269644, 0.36269644, 0.36269644, 0.37089202,
0.36269644, 0.36450342]),
'split2_train_precision_score': array([0.25181107, 0.25750818, 0.24374653, 0.24374653, 0.24374653,
0.2560615 , 0.24374653, 0.24423451, 0.24374653, 0.24374653,
0.24985722, 0.24559933, 0.24374653, 0.24374653, 0.24374653,
0.24374653, 0.24374653, 0.30086269, 0.24374653, 0.24374653,
0.24374653, 0.24374653, 0.24374653, 0.24374653, 0.25028703,
0.24374653, 0.24374653, 0.24374653, 0.24374653, 0.25203016,
0.24374653, 0.24518828]),
'split2_train_recall_score': array([0.70193861, 0.69951535, 0.70840065, 0.70840065, 0.70840065,
0.69951535, 0.70840065, 0.71001616, 0.70840065, 0.70840065,
0.70678514, 0.71001616, 0.70840065, 0.70840065, 0.70840065,
0.70840065, 0.70840065, 0.67609047, 0.70840065, 0.70840065,
0.70840065, 0.70840065, 0.70840065, 0.70840065, 0.70436187,
0.70840065, 0.70840065, 0.70840065, 0.70840065, 0.70193861,
0.70840065, 0.71001616]),
'split2_train_roc_auc_score': array([0.7185047 , 0.72165384, 0.71460459, 0.71460459, 0.71460459,
0.72067908, 0.71460459, 0.71546365, 0.71460459, 0.71460459,
0.71861933, 0.71648971, 0.71460459, 0.71460459, 0.71460459,
0.71460459, 0.71460459, 0.73826071, 0.71460459, 0.71460459,
0.71460459, 0.71460459, 0.71460459, 0.71460459, 0.71817724,
0.71460459, 0.71460459, 0.71460459, 0.71460459, 0.71865861,
0.71460459, 0.71618189]),
'std_fit_time': array([2.97678023e-02, 2.31863444e-02, 9.24166803e+01, 5.23058512e+01,
1.67026174e-02, 3.08859720e-02, 4.83287906e+01, 3.33386283e-02,
1.47510888e-02, 4.25310012e+01, 1.64296275e-02, 1.90685913e-02,
1.38970903e-02, 3.28638376e-02, 4.07948139e-03, 1.87139895e-02,
9.30969376e+01, 2.58746501e-02, 2.38119671e+01, 1.05388244e-02,
1.58379507e+02, 1.54944431e+01, 2.00022270e+00, 2.48939543e-02,
3.38547210e-02, 4.40219160e+00, 1.62696960e-02, 1.66470769e+02,
1.52511890e+01, 2.22412152e-02, 1.99572133e-02, 2.57084695e-02]),
'std_score_time': array([0.04097246, 0.03585187, 0.00744359, 0.00733791, 0.0191123 ,
0.0404021 , 0.00631185, 0.02630991, 0.04797898, 0.00699407,
0.03422724, 0.01708906, 0.01789274, 0.02254381, 0.06877233,
0.02917472, 0.01153483, 0.01940942, 0.00970903, 0.03234069,
0.00809172, 0.00611741, 0.00579239, 0.05327223, 0.0172072 ,
0.00812112, 0.0150566 , 0.00717869, 0.00834689, 0.01869914,
0.02305308, 0.03642369]),
'std_test_accuracy_score': array([0.0052071 , 0.00579786, 0.00237952, 0.00237952, 0.00237952,
0.00630598, 0.00237952, 0.00245529, 0.00245529, 0.00237952,
0.00453847, 0.00238676, 0.00237952, 0.00237952, 0.00237952,
0.00237952, 0.00237952, 0.00447766, 0.00237952, 0.00237952,
0.00237952, 0.00237952, 0.00237952, 0.00237952, 0.00438704,
0.00237952, 0.00245529, 0.00237952, 0.00237952, 0.00498089,
0.00237952, 0.00223862]),
'std_test_average_precision_score': array([0.00453488, 0.00472846, 0.00317809, 0.00317809, 0.00317809,
0.00486365, 0.00317809, 0.00337532, 0.00323284, 0.00317809,
0.00356621, 0.00348559, 0.00317809, 0.00317809, 0.00317809,
0.00317809, 0.00317809, 0.00419341, 0.00317809, 0.00317809,
0.00317809, 0.00317809, 0.00317809, 0.00317809, 0.00365881,
0.00317809, 0.00318704, 0.00317809, 0.00317809, 0.00436175,
0.00317809, 0.00342391]),
'std_test_f1_score': array([0.00645852, 0.00684878, 0.00432197, 0.00432197, 0.00432197,
0.00709155, 0.00432197, 0.00458056, 0.00441316, 0.00432197,
0.00508329, 0.00472771, 0.00432197, 0.00432197, 0.00432197,
0.00432197, 0.00432197, 0.00456764, 0.00432197, 0.00432197,
0.00432197, 0.00432197, 0.00432197, 0.00432197, 0.00514523,
0.00432197, 0.00434007, 0.00432197, 0.00432197, 0.00622302,
0.00432197, 0.00463561]),
'std_test_precision_score': array([0.00509482, 0.00575098, 0.00290802, 0.00290802, 0.00290802,
0.00603939, 0.00290802, 0.00306702, 0.00299383, 0.00290802,
0.00401467, 0.00317194, 0.00290802, 0.00290802, 0.00290802,
0.00290802, 0.00290802, 0.00143289, 0.00290802, 0.00290802,
0.00290802, 0.00290802, 0.00290802, 0.00290802, 0.00397313,
0.00290802, 0.00293404, 0.00290802, 0.00290802, 0.00492077,
0.00290802, 0.00307793]),
'std_test_recall_score': array([0.0097477 , 0.00686484, 0.00957997, 0.00957997, 0.00957997,
0.00700376, 0.00957997, 0.01017622, 0.00947923, 0.00957997,
0.00902358, 0.01017622, 0.00957997, 0.00957997, 0.00957997,
0.00957997, 0.00957997, 0.02471501, 0.00957997, 0.00957997,
0.00957997, 0.00957997, 0.00957997, 0.00957997, 0.01014425,
0.00957997, 0.00957997, 0.00957997, 0.00957997, 0.00887569,
0.00957997, 0.01017622]),
'std_test_roc_auc_score': array([0.00586398, 0.00547542, 0.00478331, 0.00478331, 0.00478331,
0.00556611, 0.00478331, 0.00509392, 0.0048349 , 0.00478331,
0.00476394, 0.00520423, 0.00478331, 0.00478331, 0.00478331,
0.00478331, 0.00478331, 0.00834158, 0.00478331, 0.00478331,
0.00478331, 0.00478331, 0.00478331, 0.00478331, 0.00502112,
0.00478331, 0.00478822, 0.00478331, 0.00478331, 0.00559034,
0.00478331, 0.00515529]),
'std_train_accuracy_score': array([0.0028794 , 0.0043955 , 0.00118987, 0.00118987, 0.00118987,
0.00471614, 0.00118987, 0.00119074, 0.00122776, 0.00118987,
0.00119346, 0.00102969, 0.00118987, 0.00118987, 0.00118987,
0.00118987, 0.00118987, 0.00868929, 0.00118987, 0.00115349,
0.00118987, 0.00118987, 0.00118987, 0.00118987, 0.00212179,
0.00118987, 0.00122776, 0.00118987, 0.00118987, 0.00276533,
0.00118987, 0.00114723]),
'std_train_average_precision_score': array([0.00094277, 0.00115721, 0.00157753, 0.00157753, 0.00157753,
0.00116459, 0.00157753, 0.00168944, 0.00167384, 0.00157753,
0.00163377, 0.00176226, 0.00157753, 0.00157753, 0.00157753,
0.00157753, 0.00157753, 0.00675104, 0.00157753, 0.00155668,
0.00157753, 0.00157753, 0.00157753, 0.00157753, 0.00102433,
0.00157753, 0.00169836, 0.00157753, 0.00157753, 0.00099668,
0.00157753, 0.00176381]),
'std_train_f1_score': array([0.00142383, 0.00197178, 0.00215235, 0.00215235, 0.00215235,
0.00211294, 0.00215235, 0.00229382, 0.00228383, 0.00215235,
0.00203871, 0.00238383, 0.00215235, 0.00215235, 0.00215235,
0.00215235, 0.00215235, 0.00958205, 0.00215235, 0.0021186 ,
0.00215235, 0.00215235, 0.00215235, 0.00215235, 0.00125338,
0.00215235, 0.00232279, 0.00215235, 0.00215235, 0.00149459,
0.00215235, 0.00239595]),
'std_train_precision_score': array([0.00158445, 0.00258465, 0.00144986, 0.00144986, 0.00144986,
0.00281118, 0.00144986, 0.00152931, 0.00153159, 0.00144986,
0.00112362, 0.00157476, 0.00144986, 0.00144986, 0.00144986,
0.00144986, 0.00144986, 0.0105987 , 0.00144986, 0.00141837,
0.00144986, 0.00144986, 0.00144986, 0.00144986, 0.00098232,
0.00144986, 0.00156232, 0.00144986, 0.00144986, 0.00158905,
0.00144986, 0.00160053]),
'std_train_recall_score': array([0.00571046, 0.00774184, 0.00478703, 0.00478703, 0.00478703,
0.00743954, 0.00478703, 0.00514585, 0.00508552, 0.00478703,
0.00714267, 0.00514585, 0.00478703, 0.00478703, 0.00478703,
0.00478703, 0.00478703, 0.00527844, 0.00478703, 0.00478703,
0.00478703, 0.00478703, 0.00478703, 0.00478703, 0.00649503,
0.00478703, 0.0050527 , 0.00478703, 0.00478703, 0.00539666,
0.00478703, 0.00514585]),
'std_train_roc_auc_score': array([0.0014449 , 0.00154241, 0.00239057, 0.00239057, 0.00239057,
0.00121566, 0.00239057, 0.00256801, 0.00254607, 0.00239057,
0.00278111, 0.00264964, 0.00239057, 0.00239057, 0.00239057,
0.00239057, 0.00239057, 0.00442714, 0.00239057, 0.00236857,
0.00239057, 0.00239057, 0.00239057, 0.00239057, 0.00194769,
0.00239057, 0.00257353, 0.00239057, 0.00239057, 0.00145944,
0.00239057, 0.0026489 ])}
# Use learning curve to get training and test scores along with train sizes
#
# NOTE(review): learning_curve scores with the estimator's default scorer;
# the labels below assume accuracy — confirm svm_model_selection's scoring.
train_sizes, train_scores, test_scores = learning_curve(estimator=svm_model_selection, X=X_train_svm, y=y_train_svm,
cv=3, train_sizes=np.linspace(0.1, 1.0, 3),n_jobs=-1)
#
# Calculate training and test mean and std
#
# Means/stds are taken across the cv folds (axis=1) for each training size.
train_mean = np.mean(train_scores, axis=1)
train_std = np.std(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
test_std = np.std(test_scores, axis=1)
#
# Plot the learning curve
plt.plot(train_sizes, train_mean, color='blue', marker='o', markersize=5, label='Training Accuracy')
# Shaded band = +/- one standard deviation across folds.
plt.fill_between(train_sizes, train_mean + train_std, train_mean - train_std, alpha=0.15, color='blue')
plt.plot(train_sizes, test_mean, color='green', marker='+', markersize=5, linestyle='--', label='Validation Accuracy')
plt.fill_between(train_sizes, test_mean + test_std, test_mean - test_std, alpha=0.15, color='green')
plt.title('Learning Curve')
plt.xlabel('Training Data Size')
plt.ylabel('Model accuracy')
plt.grid()
plt.legend(loc='lower right')
plt.show()
# Plot the learning curve
# NOTE(review): this cell duplicates the plot above except for the '(SVM)'
# suffix in the title — presumably kept for a report figure.
plt.plot(train_sizes, train_mean, color='blue', marker='o', markersize=5, label='Training Accuracy')
plt.fill_between(train_sizes, train_mean + train_std, train_mean - train_std, alpha=0.15, color='blue')
plt.plot(train_sizes, test_mean, color='green', marker='+', markersize=5, linestyle='--', label='Validation Accuracy')
plt.fill_between(train_sizes, test_mean + test_std, test_mean - test_std, alpha=0.15, color='green')
plt.title('Learning Curve (SVM)')
plt.xlabel('Training Data Size')
plt.ylabel('Model accuracy')
plt.grid()
plt.legend(loc='lower right')
plt.show()
# Fold-wise means/stds, recomputed for the validation-curve style plot below.
train_scores_mean = np.mean(train_scores, axis=1)
train_scores_std = np.std(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)
# NOTE(review): this cell is adapted from sklearn's validation_curve example,
# but train_scores/test_scores above came from learning_curve (3 train-size
# rows), not validation_curve over param_range — the gamma x-axis only lines
# up if param_range has matching length. Confirm the intended data source.
plt.title("Validation Curve with SVM")
plt.xlabel(r"$\gamma$")
plt.ylabel("Score")
plt.ylim(0.0, 1.1)
lw = 2
plt.semilogx(param_range, train_scores_mean, label="Training score",
color="darkorange", lw=lw)
plt.fill_between(param_range, train_scores_mean - train_scores_std,
train_scores_mean + train_scores_std, alpha=0.2,
color="darkorange", lw=lw)
plt.semilogx(param_range, test_scores_mean, label="Cross-validation score",
color="navy", lw=lw)
plt.fill_between(param_range, test_scores_mean - test_scores_std,
test_scores_mean + test_scores_std, alpha=0.2,
color="navy", lw=lw)
plt.legend(loc="best")
plt.show()
#svm_best_model = train_svm_best_model(X_train,y_train,svm_model_selection)
# Assemble the random-search results (sampled params + per-metric mean test
# scores from cv_results_) into one DataFrame for inspection and export.
df_RandomSearchResults_svm = pd.concat([pd.DataFrame(svm_model_selection.cv_results_["params"]),
pd.DataFrame(svm_model_selection.cv_results_['mean_test_accuracy_score'], columns=['Accuracy']),
pd.DataFrame(svm_model_selection.cv_results_['mean_test_roc_auc_score'],columns=['ROC_AUC']),
pd.DataFrame(svm_model_selection.cv_results_['mean_test_precision_score'],columns=['precision']),
pd.DataFrame(svm_model_selection.cv_results_['mean_test_recall_score'],columns=['recall']),
pd.DataFrame(svm_model_selection.cv_results_['mean_test_average_precision_score'], columns=['PR_AUC'])],axis=1)
# Export all candidates (head(32) = every sampled candidate) ranked by ROC-AUC.
df_RandomSearchResults_svm.sort_values(by=['ROC_AUC'],ascending=False).head(32).to_csv('svm_random_search_hyperparameters.csv')
# with open('svm_model_selection.pkl', 'wb') as fid:
# pkl.dump(svm_model_selection, fid)
#files.download('svm_model_selection.pkl')
#files.download('svm_random_search_hyperparameters.csv')
#pd.options.display.max_rows = 100
# Compact view of the same search results (Accuracy / ROC-AUC / PR-AUC only),
# rebuilt from cv_results_ and displayed sorted by ROC-AUC.
df_RandomSearchResults_svm = pd.concat([pd.DataFrame(svm_model_selection.cv_results_["params"]),
pd.DataFrame(svm_model_selection.cv_results_['mean_test_accuracy_score'], columns=['Accuracy']),
pd.DataFrame(svm_model_selection.cv_results_['mean_test_roc_auc_score'],columns=['ROC_AUC']),
pd.DataFrame(svm_model_selection.cv_results_['mean_test_average_precision_score'], columns=['PR_AUC'])],axis=1)
df_RandomSearchResults_svm.sort_values(by=['ROC_AUC'],ascending=False)
| svm__C | svm__gamma | svm__kernel | Accuracy | ROC_AUC | PR_AUC | |
|---|---|---|---|---|---|---|
| 17 | 679.954817 | 0.000318 | rbf | 0.785250 | 0.721179 | 0.227452 |
| 1 | 45.075890 | 0.000524 | rbf | 0.739423 | 0.716520 | 0.210983 |
| 5 | 153.868288 | 0.000275 | rbf | 0.738634 | 0.716311 | 0.210680 |
| 29 | 57.254504 | 0.000422 | rbf | 0.732018 | 0.714228 | 0.207898 |
| 0 | 137.678448 | 0.000268 | rbf | 0.731836 | 0.714125 | 0.207801 |
| 24 | 247.277856 | 0.000192 | rbf | 0.729408 | 0.713933 | 0.207138 |
| 10 | 29.940175 | 0.000550 | rbf | 0.727284 | 0.713442 | 0.206400 |
| 11 | 147.004337 | 0.000210 | rbf | 0.720910 | 0.712908 | 0.204779 |
| 31 | 195.455783 | 0.000175 | rbf | 0.720182 | 0.712498 | 0.204394 |
| 7 | 163.676458 | 0.000152 | rbf | 0.719393 | 0.712053 | 0.203981 |
| 9 | 80.020798 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 21 | 28.125665 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 30 | 38.921183 | 0.000150 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 2 | 218.818122 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 28 | 23.455449 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 27 | 423.029375 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 25 | 10.673123 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 3 | 113.367669 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 23 | 68.973709 | 0.000132 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 22 | 4.994681 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 20 | 562.465671 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 19 | 8.950689 | 0.000260 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 18 | 50.733179 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 4 | 2.629374 | 0.000252 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 6 | 121.413095 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 15 | 18.944837 | 0.000205 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 14 | 1.213238 | 0.000149 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 13 | 78.149890 | 0.000124 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 12 | 13.225262 | 0.000172 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 16 | 178.293748 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 26 | 10.529020 | 0.000489 | rbf | 0.719393 | 0.711818 | 0.203841 |
| 8 | 1.783115 | 0.000516 | rbf | 0.719393 | 0.711583 | 0.203705 |
Plot SVM Hyperparameter Results
df_RandomSearchResults_svm
| C | Gamma | Kernel | Accuracy | ROC_AUC | PR_AUC | |
|---|---|---|---|---|---|---|
| 0 | 137.678448 | 0.000268 | 1 | 0.731836 | 0.714125 | 0.207801 |
| 1 | 45.075890 | 0.000524 | 1 | 0.739423 | 0.716520 | 0.210983 |
| 2 | 218.818122 | 0.000000 | 0 | 0.719454 | 0.711852 | 0.203872 |
| 3 | 113.367669 | 0.000000 | 0 | 0.719454 | 0.711852 | 0.203872 |
| 4 | 2.629374 | 0.000252 | 1 | 0.719454 | 0.711852 | 0.203872 |
| 5 | 153.868288 | 0.000275 | 1 | 0.738634 | 0.716311 | 0.210680 |
| 6 | 121.413095 | 0.000000 | 0 | 0.719454 | 0.711852 | 0.203872 |
| 7 | 163.676458 | 0.000152 | 1 | 0.719393 | 0.712053 | 0.203981 |
| 8 | 1.783115 | 0.000516 | 1 | 0.719393 | 0.711583 | 0.203705 |
| 9 | 80.020798 | 0.000000 | 0 | 0.719454 | 0.711852 | 0.203872 |
| 10 | 29.940175 | 0.000550 | 1 | 0.727284 | 0.713442 | 0.206400 |
| 11 | 147.004337 | 0.000210 | 1 | 0.720910 | 0.712908 | 0.204779 |
| 12 | 13.225262 | 0.000172 | 1 | 0.719454 | 0.711852 | 0.203872 |
| 13 | 78.149890 | 0.000124 | 1 | 0.719454 | 0.711852 | 0.203872 |
| 14 | 1.213238 | 0.000149 | 1 | 0.719454 | 0.711852 | 0.203872 |
| 15 | 18.944837 | 0.000205 | 1 | 0.719454 | 0.711852 | 0.203872 |
| 16 | 178.293748 | 0.000000 | 0 | 0.719454 | 0.711852 | 0.203872 |
| 17 | 679.954817 | 0.000318 | 1 | 0.785250 | 0.721179 | 0.227452 |
| 18 | 50.733179 | 0.000000 | 0 | 0.719454 | 0.711852 | 0.203872 |
| 19 | 8.950689 | 0.000260 | 1 | 0.719454 | 0.711852 | 0.203872 |
| 20 | 562.465671 | 0.000000 | 0 | 0.719454 | 0.711852 | 0.203872 |
| 21 | 28.125665 | 0.000000 | 0 | 0.719454 | 0.711852 | 0.203872 |
| 22 | 4.994681 | 0.000000 | 0 | 0.719454 | 0.711852 | 0.203872 |
| 23 | 68.973709 | 0.000132 | 1 | 0.719454 | 0.711852 | 0.203872 |
| 24 | 247.277856 | 0.000192 | 1 | 0.729408 | 0.713933 | 0.207138 |
| 25 | 10.673123 | 0.000000 | 0 | 0.719454 | 0.711852 | 0.203872 |
| 26 | 10.529020 | 0.000489 | 1 | 0.719393 | 0.711818 | 0.203841 |
| 27 | 423.029375 | 0.000000 | 0 | 0.719454 | 0.711852 | 0.203872 |
| 28 | 23.455449 | 0.000000 | 0 | 0.719454 | 0.711852 | 0.203872 |
| 29 | 57.254504 | 0.000422 | 1 | 0.732018 | 0.714228 | 0.207898 |
| 30 | 38.921183 | 0.000150 | 1 | 0.719454 | 0.711852 | 0.203872 |
| 31 | 195.455783 | 0.000175 | 1 | 0.720182 | 0.712498 | 0.204394 |
Based on ROC Curve
# Rename columns to a more explainable name
df_RandomSearchResults_svm.rename({'svm__C': 'C', 'svm__gamma': 'Gamma',
                                   'svm__kernel': 'Kernel {0:Linear,1:RBF}'},
                                  axis=1, inplace=True)
# Gamma is NaN for linear-kernel candidates (gamma unused); encode it as 0 and
# map the kernel to {linear: 0, rbf: 1} so parallel_coordinates gets numeric axes.
# Fix: the original used chained inplace fillna/replace on a column selection
# (df[col].fillna(0, inplace=True)), which operates on a temporary and is
# deprecated in pandas (FutureWarning; a no-op under copy-on-write). Use
# explicit column assignment instead.
df_RandomSearchResults_svm['Gamma'] = df_RandomSearchResults_svm['Gamma'].fillna(0)
df_RandomSearchResults_svm['Kernel {0:Linear,1:RBF}'] = (
    df_RandomSearchResults_svm['Kernel {0:Linear,1:RBF}']
    .replace({'linear': 0, 'rbf': 1})
)
cols = ['C', 'Gamma', 'Kernel {0:Linear,1:RBF}', 'ROC_AUC']
fig = px.parallel_coordinates(df_RandomSearchResults_svm, color='ROC_AUC', dimensions=cols,
                              color_continuous_scale=px.colors.sequential.Viridis,
                              title="SVM Hyperparameter Search Plot (ROC_AUC)",
                              width=1000, height=700)
fig.show()
Based on Accuracy
# Based on Accuracy
df_RandomSearchResults_svm
| C | Gamma | Kernel | Accuracy | ROC_AUC | PR_AUC | |
|---|---|---|---|---|---|---|
| 0 | 137.678448 | 0.000268 | rbf | 0.731836 | 0.714125 | 0.207801 |
| 1 | 45.075890 | 0.000524 | rbf | 0.739423 | 0.716520 | 0.210983 |
| 2 | 218.818122 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 3 | 113.367669 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 4 | 2.629374 | 0.000252 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 5 | 153.868288 | 0.000275 | rbf | 0.738634 | 0.716311 | 0.210680 |
| 6 | 121.413095 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 7 | 163.676458 | 0.000152 | rbf | 0.719393 | 0.712053 | 0.203981 |
| 8 | 1.783115 | 0.000516 | rbf | 0.719393 | 0.711583 | 0.203705 |
| 9 | 80.020798 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 10 | 29.940175 | 0.000550 | rbf | 0.727284 | 0.713442 | 0.206400 |
| 11 | 147.004337 | 0.000210 | rbf | 0.720910 | 0.712908 | 0.204779 |
| 12 | 13.225262 | 0.000172 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 13 | 78.149890 | 0.000124 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 14 | 1.213238 | 0.000149 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 15 | 18.944837 | 0.000205 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 16 | 178.293748 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 17 | 679.954817 | 0.000318 | rbf | 0.785250 | 0.721179 | 0.227452 |
| 18 | 50.733179 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 19 | 8.950689 | 0.000260 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 20 | 562.465671 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 21 | 28.125665 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 22 | 4.994681 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 23 | 68.973709 | 0.000132 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 24 | 247.277856 | 0.000192 | rbf | 0.729408 | 0.713933 | 0.207138 |
| 25 | 10.673123 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 26 | 10.529020 | 0.000489 | rbf | 0.719393 | 0.711818 | 0.203841 |
| 27 | 423.029375 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 28 | 23.455449 | NaN | linear | 0.719454 | 0.711852 | 0.203872 |
| 29 | 57.254504 | 0.000422 | rbf | 0.732018 | 0.714228 | 0.207898 |
| 30 | 38.921183 | 0.000150 | rbf | 0.719454 | 0.711852 | 0.203872 |
| 31 | 195.455783 | 0.000175 | rbf | 0.720182 | 0.712498 | 0.204394 |
# Rename columns to a more explainable name
# df_RandomSearchResults_svm.rename({'svm__C': 'C', 'svm__gamma': 'Gamma',
# 'svm__kernel': 'Kernel'},
# axis=1, inplace=True)
# NOTE(review): assumes the rename/encoding cell above already ran, so the
# 'Kernel {0:Linear,1:RBF}' column exists and Gamma/Kernel are numeric.
cols = ['C', 'Gamma', 'Kernel {0:Linear,1:RBF}', 'Accuracy']
fig = px.parallel_coordinates(df_RandomSearchResults_svm, color='Accuracy', dimensions=cols,
color_continuous_scale=px.colors.sequential.Viridis,
title="SVM Hyperparameter Search Plot (Accuracy)",
width=1000, height=700)
fig.show()
Analysis of C and Gamma Parameters (Using Accuracy and ROC-AUC)
# Axes used to lay the 32 candidates out as an 8x4 (C x gamma) image.
C_range = np.logspace(-2, 5, 8)
gamma_range = np.logspace(-4, 3, 4)
param_grid = dict(gamma=gamma_range, C=C_range)
# NOTE(review): the reshape assumes cv_results_ rows are ordered as a full
# C x gamma grid. With a *randomized* search the 32 candidates are independent
# draws, so rows/columns would not correspond to C_range/gamma_range values —
# confirm this cell was meant for a GridSearchCV run before trusting the ticks.
scores_roc = svm_model_selection.cv_results_['mean_test_roc_auc_score'].reshape(len(C_range),len(gamma_range))
scores_acc = svm_model_selection.cv_results_['mean_test_accuracy_score'].reshape(len(C_range),len(gamma_range))
# Colour-scale anchors: min / midpoint / max of each metric.
min_roc = min(svm_model_selection.cv_results_['mean_test_roc_auc_score'])
max_roc = max(svm_model_selection.cv_results_['mean_test_roc_auc_score'])
mid_roc = min_roc + ((max_roc - min_roc)/2)
min_acc = min(svm_model_selection.cv_results_['mean_test_accuracy_score'])
max_acc = max(svm_model_selection.cv_results_['mean_test_accuracy_score'])
mid_acc = min_acc + ((max_acc - min_acc)/2)
fig = plt.figure(figsize=(14, 8)) # create the canvas for plotting
ax1 = plt.subplot(1,2,1)
#ax[0] = plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)
# MidpointNormalize is a project-defined Normalize subclass (defined elsewhere
# in this notebook) that centres the colormap on the midpoint value.
im1 = ax1.imshow(scores_roc, interpolation='nearest', cmap=plt.cm.viridis,
norm=MidpointNormalize(vmin=min_roc, vmax=max_roc, midpoint=mid_roc))
ax1.set_xlabel('gamma')
ax1.set_ylabel('C')
fig.colorbar(im1)
plt.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45)
plt.yticks(np.arange(len(C_range)), C_range)
ax1.set_title('Validation ROC AUC')
ax2 = plt.subplot(1,2,2)
#ax[0] = plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)
im2 = ax2.imshow(scores_acc, interpolation='nearest', cmap=plt.cm.viridis,
norm=MidpointNormalize(vmin=min_acc, vmax=max_acc, midpoint=mid_acc))
ax2.set_xlabel('gamma')
ax2.set_ylabel('C')
fig.colorbar(im2)
plt.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45)
plt.yticks(np.arange(len(C_range)), C_range)
ax2.set_title('Validation Accuracy')
plt.show()
# Save the model when finds the best train and validation losses
def train_svm_best_model(X_train_model,y_train_model,model):
    """Retrain the final SVM pipeline on the full training data.

    Builds a fresh scaling + SMOTE oversampling + random undersampling + SVC
    pipeline, copies the winning hyperparameters from `model` (a fitted
    search object exposing `best_params_`), fits it on the whole training
    set, and returns the fitted pipeline.
    """
    # Fresh, seeded pipeline stages (same settings used during model selection).
    classifier = SVC(probability=True, class_weight='balanced', random_state=2)
    stages = [
        ('scaler', preprocessing.StandardScaler()),
        ('o', SMOTE(sampling_strategy=0.2, random_state=2, k_neighbors=7)),
        ('ru', RandomUnderSampler(sampling_strategy='majority', random_state=2)),
        ('svm', classifier),
    ]
    final_pipeline = imbPipeline(stages)
    # Copy the best hyperparameters found by the search onto the pipeline's
    # correspondingly-named steps.
    final_pipeline.set_params(**model.best_params_)
    t0 = time.time()
    fitted = final_pipeline.fit(X_train_model, y_train_model.ravel())
    elapsed = time.time() - t0
    print("Final model train (SVM) took %.2f seconds (%.2f hours) to train on the entire training data" % ((elapsed), (elapsed/3600)))
    return fitted
svm_best_model = train_svm_best_model(X_train,y_train,svm_model_selection)
# save the classifier
# with open('svm_best_model.pkl', 'wb') as fid:
# pkl.dump(svm_best_model, fid)
Final model train (SVM) took 66.86 seconds (0.02 hours) to train on the entire training data
# dump(svm_model_selection,'svm_model_selection.joblib')
# dump(svm_best_model,'svm_best_model.joblib')
['svm_best_model.joblib']
We will analyse the Precision-Recall curve over the training set using the best model (Training over the entire training data)
# Positive-class probabilities of the final model on the training data.
y_pred_proba_svm = svm_best_model.predict_proba(X_train.astype(np.float32))[:,1]
precision, recall, thresholds = precision_recall_curve(y_train, y_pred_proba_svm)
# precision_recall_threshold is a project helper; t=0.5 marks the default
# decision threshold on the plot, savefig='No' skips writing the figure.
precision_recall_threshold(precision, recall, thresholds, y_train, y_pred_proba_svm, t=0.5, plot_type='SVM',savefig='No')
We will analyse the confusion matrix over the test set using the best model (trained over the entire training data)
# Predict on the unscaled test features; the pipeline's scaler step handles scaling.
y_pred_svm = svm_best_model.predict(X_test_unscaled)
conf_matrix = confusion_matrix(y_test.ravel(), y_pred_svm)
# Fix: the original passed a positional argument ('No') after the keyword
# argument plot_type=, which is a SyntaxError. Pass it as savefig=, matching
# how the sibling helper is called above (savefig='No').
plot_confusion_matrix(conf_matrix, target_names= ['Not Subscribed to term deposit', 'Subscribed to term deposit'],
                      plot_type='svm', savefig='No')
# Save to png
# Save to png
Now we want to analyse how the model can be compared against the training data:
True Positives: Total number of correct predictions of a client subscription.
False Positives: The model predicted a subscription but the client did not subscribe.
True Negatives: Total number of correct predictions of clients not subscribed.
False Negative: The model predicted a non subscription but the client is subscribed.
From the above definitions we can clearly state that False Positives can generate less benefits compared to False Negatives. Therefore it is more important to have less False Positives than the opposite.
F1-Score: Support: Each class has a support number to represent the amount of examples for the class.
#gs.fit(X_test_net.astype(np.float32), y_test_net.astype(np.float32).squeeze(1))
# Evaluate the final model on the *training* data (an optimistic estimate).
y_train_preds = svm_best_model.predict(X_train.astype(np.float32))
# squeeze(1) assumes y_train is a 2-D column vector (n, 1) — TODO confirm.
actual_label = y_train.astype(np.float32).squeeze(1)
f1 = f1_score(actual_label, y_train_preds)
#fbeta_score = fbeta_score(actual_label, y_train_preds, average='weighted', beta=0.5)
accuracy = accuracy_score(actual_label, y_train_preds)
# NOTE(review): roc_auc_score is fed hard 0/1 predictions here, not
# probabilities, so this is the AUC of a single operating point; use
# predict_proba scores for a true ROC-AUC.
roc_auc = roc_auc_score(actual_label, y_train_preds)
cm = confusion_matrix(actual_label, y_train_preds)
report = classification_report(actual_label, y_train_preds)
print("Model ROC-AUC(Train Data): ", roc_auc)
print("Model F1-Score (Train Data): ", f1)
#print("Model FBeta-Score (Train Data): ", fbeta_score)
print("Model Accuracy: ", accuracy)
print("Confusion Matrix:\n", cm)
print("\nClassification Report:\n", report)
Model ROC-AUC(Train Data): 0.7341615964315451
Model F1-Score (Train Data): 0.41656744344276236
Model Accuracy: 0.7918057663125948
Confusion Matrix:
[[23641 5597]
[ 1263 2449]]
Classification Report:
precision recall f1-score support
0.0 0.95 0.81 0.87 29238
1.0 0.30 0.66 0.42 3712
accuracy 0.79 32950
macro avg 0.63 0.73 0.64 32950
weighted avg 0.88 0.79 0.82 32950
# Positive-class probabilities of the final SVM on the training data.
yhat_svm = svm_best_model.predict_proba(X_train.astype(np.float32))
# keep probabilities for the positive outcome only
yhat_svm = yhat_svm[:, 1]
# calculate roc curves
fpr, tpr, thresholds = roc_curve(y_train, yhat_svm)
plot_roc_auc_thresholds(fpr, tpr, thresholds, y_train, yhat_svm, plot_type='SVM')
# calculate the g-mean for each threshold: sqrt(sensitivity * specificity)
gmeans = sqrt(tpr * (1-fpr))
# locate the index of the largest g-mean (best balanced threshold)
ix = argmax(gmeans)
print('Best Threshold=%f, G-Mean=%.3f' % (thresholds[ix], gmeans[ix]))
# plot the roc curve for the model
plt.plot([0,1], [0,1], linestyle='--', label='No Skill')
# Fix: this curve is for the SVM model — the original legend label
# 'Logistic' was left over from a logistic-regression tutorial.
plt.plot(fpr, tpr, marker='.', label='SVM')
plt.scatter(fpr[ix], tpr[ix], marker='o', color='black', label='Best')
# axis labels
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend()
# show the plot
plt.show()
Best Threshold=0.472036, G-Mean=0.731
# Data to plot precision - recall curve
precision, recall, thresholds = precision_recall_curve(y_train, yhat_svm)
# Use AUC function to calculate the area under the curve of precision recall curve
# (trapezoidal integration of precision over recall).
auc_precision_recall = auc(recall, precision)
print('PR-AUC:',auc_precision_recall)
PR-AUC: 0.39958878639923034
precision_recall_threshold(precision, recall, thresholds, y_train, yhat_svm, t=0.5, plot_type='SVM',savefig='No')
findfont: Font family ['Arial'] not found. Falling back to DejaVu Sans.
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    """Plot precision and recall as functions of the decision threshold.

    Adapted from "Hands-On Machine Learning with Scikit-Learn and
    TensorFlow", p. 89. `precisions` and `recalls` carry one more element
    than `thresholds` (sklearn convention), so the final value of each
    is dropped before plotting.
    """
    plt.figure(figsize=(8, 8))
    plt.title("Precision and Recall Scores as a function of the decision threshold")
    # Draw both series with the same x-axis (the thresholds).
    for series, fmt, name in ((precisions, "b--", "Precision"),
                              (recalls, "g-", "Recall")):
        plt.plot(thresholds, series[:-1], fmt, label=name)
    plt.ylabel("Score")
    plt.xlabel("Decision Threshold")
    plt.legend(loc='best')
plot_precision_recall_vs_threshold(precision, recall, thresholds)
Confusion Matrix of the SVM model selection (Training with cross-validation over 50% of the training sample)
# Confusion matrix of the model-selection stage's predictions on the training data.
# NOTE(review): svm_model_selection_pred is not defined in this chunk —
# presumably the fitted search/prediction object; confirm the variable name.
y_pred_svm = svm_model_selection_pred.predict(X_train)
conf_matrix = confusion_matrix(y_train.ravel(), y_pred_svm)
plot_confusion_matrix(conf_matrix, target_names=['Not Subscribed', 'Subscribed'],plot_type='svm')
Confusion Matrix of the best selected SVM model (Training with entire training data)
# Confusion matrix of the final (fully retrained) SVM on the training data.
y_pred_svm = svm_best_model.predict(X_train)
conf_matrix = confusion_matrix(y_train.ravel(), y_pred_svm)
plot_confusion_matrix(conf_matrix, target_names= ['Not Subscribed', 'Subscribed'],plot_type='svm')
# from google.colab import files
# files.download('svm_best_model.pkl')
# Training Data
# save the classifier
# Persist the final fitted pipeline with pickle for later reuse.
with open('svm_best_model.pkl', 'wb') as fid:
pkl.dump(svm_best_model, fid)
# Refit the search object and take signed decision-function scores on the test set.
y_score_svm = svm_model_selection.fit(X_train_svm, y_train_svm).decision_function(X_test_net)
n_classes = 2
# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
# NOTE(review): adapted from sklearn's multiclass ROC example. For a binary
# SVC, decision_function returns a 1-D array, so y_score_svm[:, i] (and
# y_test_net[:, i]) assume 2-D / one-hot arrays — confirm the shapes before
# trusting this loop.
for i in range(n_classes):
fpr[i], tpr[i], _ = roc_curve(y_test_net[:, i], y_score_svm[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_test_net.ravel(), y_score_svm.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
Due to the large size of the training set the SVM algorithm can take a long time to complete the hyperparameter search. The fit time scales at least quadratically with the number of samples and may be impractical beyond 10000 samples.
svm_model_selection.best_estimator_
Pipeline(memory=None,
steps=[('scaler',
StandardScaler(copy=True, with_mean=True, with_std=True)),
('o',
SMOTE(k_neighbors=7, kind='deprecated',
m_neighbors='deprecated', n_jobs=1,
out_step='deprecated', random_state=2, ratio=None,
sampling_strategy=0.2, svm_estimator='deprecated')),
('ru',
RandomUnderSampler(random_state=2, ratio=None,
replacement=False, return_indices=False,
sampling_strategy='majority')),
('svm',
SVC(C=467.30141794802967, break_ties=False, cache_size=200,
class_weight=None, coef0=0.0,
decision_function_shape='ovr', degree=3,
gamma=0.0005628319006580491, kernel='rbf', max_iter=-1,
probability=False, random_state=None, shrinking=True,
tol=0.001, verbose=False))],
verbose=False)
#X_train_svm_scaled = scaler.StandardScaler()
#scaler = StandardScaler()
#X_train_svm_scaled = scaler.fit_transform(X_train_net.astype(np.float32))
#X_test_svm_scaled = scaler.transform(X_test_net.astype(np.float32))
# svm_grid.best_estimator_.fit(X_train_net, y_train_net)
# #y_pred = svm_grid.best_estimator_.predict(X_train_svm_scaled)
# #accuracy_score(y_train_net, y_pred)
# Classification report of the search's best estimator on its own training subset.
y_true, y_pred = y_train_svm, svm_model_selection.predict(X_train_svm)
print(classification_report(y_true, y_pred))
precision recall f1-score support
0 0.95 0.80 0.87 8771
1 0.30 0.68 0.42 1114
accuracy 0.79 9885
macro avg 0.63 0.74 0.65 9885
weighted avg 0.88 0.79 0.82 9885
# Evaluate the classifier on the held-out test set.
print("Predicting subscriptions on the test set")
t0 = time()
y_pred_svm = clf.predict(X_test_net)
print("done in %0.3fs" % (time() - t0))
# NOTE(review): target_names maps to labels in sorted order [0, 1]; verify
# that class 0 really is 'Yes' here — for this dataset 0 is usually 'No'.
target_names = ['Yes','No']
print(classification_report(y_test_net, y_pred_svm, target_names=target_names))
# Fix: the original referenced undefined names y_test / y_pred here (the cell
# raised NameError, see traceback below); use the test labels and the
# predictions computed above. The print message was also leftover from the
# sklearn faces-recognition tutorial ("people's names").
print(confusion_matrix(y_test_net, y_pred_svm, labels=range(n_classes)))
Predicting people's names on the test set
done in 4.578s
precision recall f1-score support
Yes 0.93 0.93 0.93 7303
No 0.44 0.41 0.43 935
accuracy 0.87 8238
macro avg 0.68 0.67 0.68 8238
weighted avg 0.87 0.87 0.87 8238
--------------------------------------------------------------------------- NameError Traceback (most recent call last) <ipython-input-37-801e84541562> in <module> 5 target_names = ['Yes','No'] 6 print(classification_report(y_test_net, y_pred_svm, target_names=target_names)) ----> 7 print(confusion_matrix(y_test, y_pred, labels=range(n_classes))) NameError: name 'y_pred' is not defined
# NOTE(review): this cell looks like leftover scratch from a tutorial — the
# output below shows digit-like labels, so X and y here are presumably not
# the bank-marketing arrays; confirm before reusing. C=1E6 gives an almost
# hard-margin SVM.
clf = SVC(kernel='rbf', C=1E6)
clf.fit(X, y)
y
array([3, 7, 7, ..., 5, 9, 5], dtype=int64)
# Load SVM model selection
# Reload the persisted search object (joblib) instead of re-running the search.
svm_model_selection_filename = 'svm_model_selection.joblib'
svm_model_selection = load(svm_model_selection_filename)
# with open('svm_best_model.pkl', 'rb') as handle:
# svm_model_selection_model = pkl.load(handle)
Saving test set for prediction on the saved models
# Scale and fit using train distribution and apply scaler to transform test features
#scaler = preprocessing.StandardScaler()
#X_train_scaled = scaler.fit_transform(X_train)
#X_test_scaled = scaler.transform(X_test)
# Persist the held-out test set (features + target) so the saved models can
# be evaluated later without re-running the train/test split.
df_X_test = pd.DataFrame(data=X_test,columns=test_df_trans.columns)
df_y_test = pd.DataFrame(data=y_test,columns=['y'])
frames_test = [df_X_test,df_y_test]
df_test = pd.concat(frames_test, axis=1)
# Fix: the original wrote test.csv twice back-to-back; one write suffices.
df_test.to_csv('test.csv',index=False)
#torch.save(bestmodel.state_dict(), PATH)
# saving
# with open('best_model_MLP.pkl', 'wb') as f:
# pkl.dump(rs, f)
# Persist the tuned MLP's weights via skorch's save_params.
nnet.save_params(f_params='best_model_MLP.pkl')
#torch.save(bestmodel.state_dict(), PATH)
# saving
# NOTE(review): rs_SVM is not defined anywhere visible and this cell raised
# NameError when run (see traceback below) — presumably the SVM search object
# was meant (e.g. svm_model_selection); confirm the intended variable.
with open('best_model_SVM.pkl', 'wb') as f:
pkl.dump(rs_SVM, f)
--------------------------------------------------------------------------- NameError Traceback (most recent call last) <ipython-input-576-4e15ed4cb30e> in <module> 3 # saving 4 with open('best_model_SVM.pkl', 'wb') as f: ----> 5 pkl.dump(rs_SVM, f) NameError: name 'rs_SVM' is not defined
class AdaBound(Optimizer):
    """AdaBound optimizer: Adam with dynamic bounds on the learning rate.

    AdaBound code from https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py
    It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): Adam learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        final_lr (float, optional): final (SGD) learning rate (default: 0.1)
        gamma (float, optional): convergence speed of the bound functions (default: 1e-3)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm

    .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate:
        https://openreview.net/forum?id=Bkg3g2R9FX
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3,
                 eps=1e-8, weight_decay=0, amsbound=False):
        # Validate hyperparameters up front so a bad config fails fast.
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= final_lr:
            raise ValueError("Invalid final learning rate: {}".format(final_lr))
        if not 0.0 <= gamma < 1.0:
            raise ValueError("Invalid gamma parameter: {}".format(gamma))
        defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps,
                        weight_decay=weight_decay, amsbound=amsbound)
        super(AdaBound, self).__init__(params, defaults)
        # Remember each group's initial lr so lr-scheduler decay can be
        # propagated onto final_lr inside step().
        self.base_lrs = list(map(lambda group: group['lr'], self.param_groups))

    def __setstate__(self, state):
        super(AdaBound, self).__setstate__(state)
        # Older checkpoints may lack the amsbound flag; default it off.
        for group in self.param_groups:
            group.setdefault('amsbound', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, please consider SparseAdam instead')
                amsbound = group['amsbound']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsbound:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsbound:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    # Fold the L2 penalty into the gradient.
                    # Fix: use the keyword `alpha` overload — the positional
                    # Tensor.add(scalar, tensor) form used by the original is
                    # deprecated and removed in recent PyTorch releases; the
                    # computation is numerically identical.
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient
                # (same deprecation fix: alpha=/value= keyword overloads).
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsbound:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                # Applies bounds on actual learning rate
                # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay
                final_lr = group['final_lr'] * group['lr'] / base_lr
                lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1))
                upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step']))
                # Clamp the per-element effective step into [lower, upper],
                # then scale by the first-moment estimate.
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)

                p.data.add_(-step_size)

        return loss
def _fit_transform_ordinal(train_df, test_df, cols, categories):
    """Fit an OrdinalEncoder on the training columns and transform both splits.

    Args:
        train_df, test_df: DataFrames holding the raw (string) categorical columns.
        cols: list of column names to encode together.
        categories: explicit category ordering passed to OrdinalEncoder, one
            list per column (codes are assigned in the listed order, 0-based).

    Returns:
        (train_encoded, test_encoded): 2-D numpy arrays of ordinal codes.
    """
    enc = OrdinalEncoder(categories=categories)
    enc.fit(train_df[cols])
    return enc.transform(train_df[cols]), enc.transform(test_df[cols])

# Transform input variables
def transform_input(X_train, X_test, quant_var, nom_var, onehot_var, df):
    """Encode the raw train/test feature matrices into model-ready DataFrames.

    Categorical columns with a meaningful order (default/housing/loan,
    poutcome, contact, day_of_week, month) are ordinal-encoded; the nominal
    columns in `nom_var` are one-hot encoded; `quant_var` columns pass
    through unchanged. All encoders are fit on the training split only.

    Args:
        X_train, X_test: raw feature matrices (numpy arrays) whose columns
            match `df.drop('y', axis=1).columns`.
        quant_var: names of the quantitative columns to keep as-is.
        nom_var: names of the nominal columns to one-hot encode.
        onehot_var: kept for interface compatibility — not used here; the
            one-hot column names are derived from the fitted encoder instead.
        df: the full source DataFrame, used only to recover column names.

    Returns:
        (train_df_trans, test_df_trans): encoded feature DataFrames with
        columns ordered as [ordinal..., one-hot..., quantitative...].
    """
    # Rebuild DataFrames so columns can be addressed by name.
    feature_cols = df.drop('y', axis=1).columns
    X_train_df = pd.DataFrame(X_train, columns=feature_cols)
    X_test_df = pd.DataFrame(X_test, columns=feature_cols)
    # Quantitative variables pass through unchanged.
    train_df_quant = X_train_df[quant_var]
    test_df_quant = X_test_df[quant_var]
    # Ordinal-encode Default / Housing / Loan with a shared category order.
    X_train_ord1, X_test_ord1 = _fit_transform_ordinal(
        X_train_df, X_test_df, ['default', 'housing', 'loan'],
        [['yes', 'unknown', 'no']] * 3)
    # Ordinal-encode Poutcome.
    X_train_ord2, X_test_ord2 = _fit_transform_ordinal(
        X_train_df, X_test_df, ['poutcome'],
        [['failure', 'nonexistent', 'success']])
    # Ordinal-encode Contact.
    X_train_ord3, X_test_ord3 = _fit_transform_ordinal(
        X_train_df, X_test_df, ['contact'],
        [['cellular', 'telephone']])
    # Ordinal-encode Day-of-week (note: the original comment here wrongly
    # said "Contact"; this encodes day_of_week).
    X_train_ord4, X_test_ord4 = _fit_transform_ordinal(
        X_train_df, X_test_df, ['day_of_week'],
        [['mon', 'tue', 'wed', 'thu', 'fri']])
    # Month uses category_encoders with an explicit calendar mapping
    # (jan=1 .. dec=12) so codes follow the calendar, not lexical order.
    ordn = ce.OrdinalEncoder(cols=['month'], return_df=True, mapping=[{
        'col': 'month', 'mapping': {
            'jan': 1, 'feb': 2, 'mar': 3, 'apr': 4, 'may': 5, 'jun': 6,
            'jul': 7, 'aug': 8, 'sep': 9, 'oct': 10, 'nov': 11, 'dec': 12}}])
    ordn.fit(X_train_df['month'])
    X_train_ord5 = ordn.transform(X_train_df[['month']])
    X_test_ord5 = ordn.transform(X_test_df[['month']])
    # One-hot encode nominal variables; unseen test categories raise an error.
    ohe = OneHotEncoder(handle_unknown='error')
    ohe.fit(X_train_df[nom_var])
    onehot_list = ohe.get_feature_names(nom_var)
    X_train_nom = ohe.transform(X_train_df[nom_var])
    X_test_nom = ohe.transform(X_test_df[nom_var])
    # Wrap the encoded arrays back into named DataFrames (training split).
    train_df_ord1 = pd.DataFrame(X_train_ord1, columns=['default', 'housing', 'loan'])
    train_df_ord2 = pd.DataFrame(X_train_ord2, columns=['poutcome'])
    train_df_ord3 = pd.DataFrame(X_train_ord3, columns=['contact'])
    train_df_ord4 = pd.DataFrame(X_train_ord4, columns=['day_of_week'])
    train_df_ord5 = pd.DataFrame(X_train_ord5, columns=['month'])
    # OneHotEncoder returns a sparse matrix; densify before wrapping.
    train_df_nom = pd.DataFrame(X_train_nom.toarray(), columns=list(onehot_list))
    # Same for the test split.
    test_df_ord1 = pd.DataFrame(X_test_ord1, columns=['default', 'housing', 'loan'])
    test_df_ord2 = pd.DataFrame(X_test_ord2, columns=['poutcome'])
    test_df_ord3 = pd.DataFrame(X_test_ord3, columns=['contact'])
    test_df_ord4 = pd.DataFrame(X_test_ord4, columns=['day_of_week'])
    test_df_ord5 = pd.DataFrame(X_test_ord5, columns=['month'])
    test_df_nom = pd.DataFrame(X_test_nom.toarray(), columns=list(onehot_list))
    # Concatenate ordinal, one-hot and quantitative parts column-wise.
    train_frames = [train_df_ord1, train_df_ord2, train_df_ord3, train_df_ord4,
                    train_df_ord5, train_df_nom, train_df_quant]
    test_frames = [test_df_ord1, test_df_ord2, test_df_ord3, test_df_ord4,
                   test_df_ord5, test_df_nom, test_df_quant]
    train_df_trans = pd.concat(train_frames, axis=1)
    test_df_trans = pd.concat(test_frames, axis=1)
    return train_df_trans, test_df_trans
# Transform target variable
def transform_target(y_train, y_test):
    """Label-encode the target arrays ('no'/'yes' -> 0/1).

    The encoder is fit on the flattened training labels only; the test
    labels are transformed with the same fitted mapping.

    Returns:
        (y_train_trans, y_test_trans): 1-D integer label arrays.
    """
    encoder = LabelEncoder()
    flat_train = np.ravel(y_train)
    encoder.fit(flat_train)
    return encoder.transform(flat_train), encoder.transform(np.ravel(y_test))
def pre_process():
    """Load the bank-marketing CSV, encode features/target, return training data.

    Reads 'bank-additional-full.csv' (semicolon-delimited), drops the leaky
    'duration' column, performs a stratified 80/20 split and encodes the
    training features/labels via transform_input / transform_target.

    Returns:
        (X_train, y_train): numpy arrays; y_train has shape (n, 1) with
        0/1 integer labels.

    NOTE(review): the test split is produced by transform_input /
    transform_target but is intentionally not returned here — callers only
    receive the training data (see the reproducibility examples below).
    """
    df = pd.read_csv('bank-additional-full.csv', delimiter=';')
    # 'duration' is only known after a call ends, so it leaks the target and
    # must be discarded for a realistic predictive model (see data notes).
    df.drop(columns=['duration'], inplace=True, errors='ignore')
    quant_var = ['age', 'campaign', 'pdays', 'previous', 'emp.var.rate',
                 'cons.price.idx', 'cons.conf.idx', 'euribor3m', 'nr.employed']
    # Nominal variables to be one-hot encoded.
    nom_var = ['job', 'marital', 'education']
    # Dummy frame is used only to record the one-hot column names.
    df_nom_var = pd.get_dummies(df[nom_var], columns=nom_var, drop_first=True)
    onehot_var = df_nom_var.columns
    # Define features and target variables accordingly.
    X = df.drop('y', axis=1).values
    y = df['y'].values.reshape(-1, 1)
    # Stratified split keeps the class-imbalance ratio in both partitions.
    X_train_raw, X_test_raw, y_train, y_test = train_test_split(
        X, y, stratify=y, test_size=0.2, random_state=42)
    # Transform input data (X); encoders are fit on the training split only.
    train_df_trans, test_df_trans = transform_input(
        X_train_raw, X_test_raw, quant_var, nom_var, onehot_var, df)
    # Transform output data (y).
    y_train_trans, y_test_trans = transform_target(y_train, y_test)
    X_train = np.array(train_df_trans)
    y_train = y_train_trans.reshape(-1, 1)
    # Scaling is deliberately deferred to the model pipelines (StandardScaler
    # inside each imbPipeline) so CV folds are scaled without leakage.
    return X_train, y_train
def mlp_train_and_model_selection(X_train_model,y_train_model,n_iter=10):
    """Randomized hyperparameter search for the MLP on resampled data.

    Builds an imbalanced-learn pipeline (standardize -> SMOTE oversampling
    to a 0.2 minority ratio -> random undersampling of the majority class ->
    skorch net) and tunes it with RandomizedSearchCV over a 5-fold
    stratified CV, refitting on ROC AUC.

    NOTE(review): relies on a module-level `nnet` (presumably a skorch
    NeuralNetClassifier wrapping the torch module) defined elsewhere in the
    file; that same shared object is returned alongside the search result.

    Args:
        X_train_model: training features; cast to float32 before fitting.
        y_train_model: training labels, shape (n, 1); cast to int64 and
            squeezed to 1-D for the skorch classifier.
        n_iter: number of random hyperparameter candidates to evaluate.

    Returns:
        (mlp_model_selection, nnet): fitted RandomizedSearchCV and the
        module-level skorch estimator.
    """
    scaler = preprocessing.StandardScaler()
    # Define resampling technique
    #smote_only = SMOTE(random_state=2) #sampling_strategy='minority'
    #under = EditedNearestNeighbours(sampling_strategy='majority', n_neighbors=7)
    over = SMOTE(sampling_strategy=0.2, random_state=2 ,k_neighbors=7)
    rand_under = RandomUnderSampler(sampling_strategy='majority', random_state=2)
    # Create a Imbalance Pipeline with Over Sampling and Under Sampling
    nnet_pipeline = imbPipeline([('scaler',scaler),
                                 #('smoteonly', smote_only),
                                 ('o', over), #('u', under),
                                 ('ru', rand_under),
                                 ('nnet', nnet)])
    # Hyperparameter search space for the skorch pipeline; distributions
    # (randint/loguniform) are sampled by RandomizedSearchCV.
    params_randcv ={
        'nnet__batch_size': [32],
        'nnet__module__hidden_dim':randint(20,80),
        'nnet__module__dropout': [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], #uniform(0, 1),
        'nnet__lr': loguniform(1e-5, 1),
        'nnet__optimizer__weight_decay': loguniform(1e-6, 1e-1),
        'nnet__optimizer': [AdaBound, optim.Adam], #, optim.RMSprop, optim.Adagrad, optim.Adadelta, optim.SGD, optim.Adamax],
        'nnet__max_epochs': randint(50,350)
    }
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=2)
    # Several metrics are tracked during CV; the search refits on ROC AUC.
    scorersMLP = {
        'precision_score': make_scorer(precision_score, zero_division=1),
        'recall_score': make_scorer(recall_score, zero_division=1),
        'accuracy_score': make_scorer(accuracy_score),
        'roc_auc_score': make_scorer(roc_auc_score),
        'average_precision_score': make_scorer(average_precision_score),
        'f1_score': make_scorer(f1_score)
    }
    # NOTE(review): this full-pipeline fit looks redundant (the search below
    # clones the pipeline and refits per candidate/fold), but it mutates the
    # shared `nnet` object that is returned, so it is kept as-is — confirm
    # before removing.
    nnet_pipeline.fit(X_train_model.astype(np.float32), y_train_model.astype(np.int64).squeeze(1))
    #y_proba = nnet_pipeline.predict_proba(X_train_model.astype(np.float32))
    mlp_rs = RandomizedSearchCV(nnet_pipeline, params_randcv, refit='roc_auc_score',
                                cv=skf, scoring=scorersMLP, return_train_score = True,
                                n_iter=n_iter, random_state=123, n_jobs=-1,
                                verbose=100)
    start = time.time()
    mlp_model_selection = mlp_rs.fit(X_train_model.astype(np.float32), y_train_model.astype(np.int64).squeeze(1))
    totaltime = time.time() - start
    print("RandomizedSearchCV (MLP) took %.2f seconds (%.2f hours) for %d selected candidates"
          " parameter settings." % ((totaltime), (totaltime/3600), n_iter))
    print("Best params: {}".format(mlp_model_selection.best_params_))
    print("Best scores: {}".format(mlp_model_selection.best_score_))
    return mlp_model_selection, nnet
def mlp_save_best_hyperparam(model):
    """Persist the best MLP hyperparameters to 'mlp_best_hyperparam.pkl'.

    Args:
        model: a fitted search object (e.g. RandomizedSearchCV) exposing
            `best_params_` with the `nnet__*` keys used below.

    Side effects:
        Writes a pickle of a plain dict (batch_size, lr, dropout,
        hidden_dim, weight_decay, optimizer, max_epochs) to the working
        directory.
    """
    # BUG FIX: the original read 'nnet__lr' from the global `mlp_best_model`
    # instead of the `model` argument, failing whenever that global was
    # absent or stale.
    best = model.best_params_
    dict_mlp_best_hyp = {'batch_size': best['nnet__batch_size'],
                         'lr': best['nnet__lr'],
                         'dropout': best['nnet__module__dropout'],
                         'hidden_dim': best['nnet__module__hidden_dim'],
                         'weight_decay': best['nnet__optimizer__weight_decay'],
                         'optimizer': best['nnet__optimizer'],
                         'max_epochs': best['nnet__max_epochs']}
    # Saving MLP best hyperparameters
    with open('mlp_best_hyperparam.pkl', 'wb') as f:
        pkl.dump(dict_mlp_best_hyp, f)
def train_mlp_final_model(X_train_model,y_train_model,model,device):
    """Retrain the MLP on the full training set with the best hyperparameters.

    Args:
        X_train_model: training features; cast to float32 before fitting.
        y_train_model: training labels, shape (n, 1); cast to int64 and
            squeezed to 1-D.
        model: fitted search object exposing `best_params_` (nnet__* keys).
        device: torch device for training.

    Returns:
        The fitted pipeline (StandardScaler + skorch classifier).

    Side effects:
        Writes 'mlp_best_hyperparam.pkl', 'mlp_best_model.pkl',
        'mlp_opt.pkl' and 'mlp_history.json' to the working directory.
    """
    best = model.best_params_
    nnet_batch_size = best['nnet__batch_size']
    nnet_lr = best['nnet__lr']
    nnet_module_dropout = best['nnet__module__dropout']
    nnet_module_hidden_dim = best['nnet__module__hidden_dim']
    nnet_optimizer_weight_decay = best['nnet__optimizer__weight_decay']
    nnet_optimizer = best['nnet__optimizer']
    nnet_max_epochs = best['nnet__max_epochs']
    scaler = preprocessing.StandardScaler()
    # BUG FIX: the original rebound the `model` parameter to the torch
    # module (losing the search result), then passed the fitted pipeline —
    # which has no best_params_ — to mlp_save_best_hyperparam, and called
    # save_params on the undefined name `nnet_retrain`. A distinct local
    # name for the torch module fixes all three.
    net_module = NeuralNet(hidden_dim=nnet_module_hidden_dim, dropout=nnet_module_dropout)
    net_module.to(device)
    # Fresh skorch instance configured with the winning hyperparameters.
    n_net_retrain = NeuralNetClassifier(
        net_module,
        max_epochs=nnet_max_epochs,
        batch_size=nnet_batch_size,
        criterion=nn.CrossEntropyLoss,
        lr=nnet_lr,
        iterator_train__shuffle=True,  # shuffle training data each epoch
        optimizer__weight_decay=nnet_optimizer_weight_decay,
        optimizer=nnet_optimizer,
        train_split=None,  # disable skorch's internal validation split
        device=device
    )
    # No SMOTE/undersampling here: the final model trains on all training data.
    nnet_pipeline_retrain = imbPipeline([('scaler', scaler),
                                         ('nnet', n_net_retrain)])
    mlp_final_model = nnet_pipeline_retrain.fit(
        X_train_model.astype(np.float32), y_train_model.astype(np.int64).squeeze(1))
    # Persist the hyperparameters of the *search* result, not the pipeline.
    mlp_save_best_hyperparam(model)
    # Save the final model trained on the entire training set.
    n_net_retrain.save_params(f_params='mlp_best_model.pkl',
                              f_optimizer='mlp_opt.pkl',
                              f_history='mlp_history.json')
    return mlp_final_model
def svm_train_and_model_selection(X_train_model, y_train_model, n_iter = 10):
    """Randomized hyperparameter search for an SVM on resampled data.

    Builds an imbalanced-learn pipeline (standardize -> SMOTE oversampling
    to a 0.2 minority ratio -> random undersampling of the majority class ->
    SVC with balanced class weights) and tunes kernel/C/gamma with
    RandomizedSearchCV over a 5-fold stratified CV, refitting on ROC AUC.

    Args:
        X_train_model: training features (2-D array-like).
        y_train_model: training labels; raveled to 1-D before fitting.
        n_iter: number of random hyperparameter candidates to evaluate.

    Returns:
        The fitted RandomizedSearchCV object.
    """
    scaler = preprocessing.StandardScaler()
    # Resampling: oversample the minority class, then undersample the majority.
    over = SMOTE(sampling_strategy=0.2, k_neighbors=7, random_state=2)
    rand_under = RandomUnderSampler(sampling_strategy='majority', random_state=2)
    # Class weights compensate for the remaining imbalance inside SVC.
    weights = 'balanced'
    svm_smp_pipeline = imbPipeline([('scaler', scaler),
                                    ('o', over), ('ru', rand_under),
                                    ('svm', SVC(class_weight=weights, random_state=2, probability=True))])
    # Search space: RBF kernel with log-uniform gamma/C, or linear with log-uniform C.
    svm_params = [{'svm__kernel': ['rbf'], 'svm__gamma': loguniform(1e-4, 1e-3), 'svm__C': loguniform(1e0, 1e3)},
                  {'svm__kernel': ['linear'], 'svm__C': loguniform(1e0, 1e3)}]
    # Several metrics are tracked during CV; the search refits on ROC AUC.
    scorersSVM = {
        'precision_score': make_scorer(precision_score, zero_division=1),
        'recall_score': make_scorer(recall_score, zero_division=1),
        'accuracy_score': make_scorer(accuracy_score),
        'roc_auc_score': make_scorer(roc_auc_score, average='weighted'),
        'average_precision_score': make_scorer(average_precision_score),
        'f1_score': make_scorer(f1_score)
    }
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=2)
    svm_rs = RandomizedSearchCV(svm_smp_pipeline, svm_params, refit='roc_auc_score', cv=skf, scoring=scorersSVM,
                                n_iter=n_iter, random_state=123, n_jobs=-1, return_train_score=True)
    starttime = time.time()
    svm_model_selection = svm_rs.fit(X_train_model, y_train_model.ravel())
    totaltime = time.time() - starttime
    print("SVM (RandomSearch) took %.2f seconds for %d CV folds." % ((totaltime), 5))
    # BUG FIX: the original printed best_params_ twice (labelled once as
    # "best score"); report the actual best score first.
    print("best score: {}, best params: {}".format(svm_rs.best_score_, svm_rs.best_params_))
    return svm_model_selection
# Retrain the SVM on the entire training set with the selected hyperparameters
def train_svm_best_model(X_train_model,y_train_model,model):
    """Refit the SVM resampling pipeline on all training data and save it.

    Rebuilds the same pipeline used during model selection
    (standardize -> SMOTE -> random undersampling -> balanced SVC), copies
    every winning hyperparameter from `model.best_params_` onto it via
    set_params, fits it on the full training set and persists the result
    to 'svm_best_model.joblib'.

    Returns:
        The fitted pipeline.
    """
    svm_estimator = SVC(probability=True, class_weight='balanced', random_state=2)
    pipeline_steps = [
        ('scaler', preprocessing.StandardScaler()),
        ('o', SMOTE(sampling_strategy=0.2, random_state=2, k_neighbors=7)),
        ('ru', RandomUnderSampler(sampling_strategy='majority', random_state=2)),
        ('svm', svm_estimator),
    ]
    retrain_pipeline = imbPipeline(pipeline_steps)
    # Adopt the best hyperparameters (svm__C, svm__kernel, ...) in one call;
    # the 'svm__' prefixes route them to the SVC step above.
    retrain_pipeline.set_params(**model.best_params_)
    t0 = time.time()
    fitted_pipeline = retrain_pipeline.fit(X_train_model, y_train_model.ravel())
    elapsed = time.time() - t0
    print("Final model train (SVM) took %.2f seconds (%.2f hours) to train on the entire training data" % ((elapsed), (elapsed/3600)))
    dump(fitted_pipeline,'svm_best_model.joblib')
    return fitted_pipeline
Example for MLP reproducibility using 64 iterations for RandomSearch (Uncomment for training)
#X_train, y_train = pre_process()
# Train MLP models with 64 hyperparameter random search, save results to csv file and return the model selection
#mlp_selected_model = mlp_train_and_model_selection(X_train, y_train, 64)
# Train best model, generate pkl output file in root folder and return the model selection
#mlp_final_model = train_mlp_final_model(X_train, y_train, mlp_selected_model, device)
Example for SVM reproducibility using 32 iterations for RandomSearch (Uncomment for training)
#X_train, y_train = pre_process()
# Train SVM models with 32 hyperparameter random search, save results to csv file and return the model selection
#svm_selected_model = svm_train_and_model_selection(X_train, y_train, 32)
# Train best model and generate joblib output file in root folder
#svm_best_model = train_svm_best_model(X_train, y_train, svm_selected_model)